# Data Processing


In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import gc

class CustomDataset(Dataset):
    """
    Dataset of frame sequences loaded from four .pt shards.

    Every sample is read as a mapping with the keys "canonical",
    "randomized", "actions" and "future_canon" (see ``__getitem__``).
    Each frame is replication-padded by 4 pixels on every side and then
    randomly cropped to ``crop_size`` x ``crop_size``.
    """

    def __init__(self, file_path1, file_path2, file_path3, file_path4):
        """
        Load the four shards and concatenate them into a single dataset.

        Args:
            file_path1 (str): Path to the first .pt file.
            file_path2 (str): Path to the second .pt file.
            file_path3 (str): Path to the third .pt file.
            file_path4 (str): Path to the fourth .pt file.
        """
        super().__init__()
        # NOTE(review): torch.load unpickles arbitrary Python objects; only
        # point this at files from a trusted source.
        shards = [
            torch.load(path)
            for path in (file_path1, file_path2, file_path3, file_path4)
        ]

        # Shards are joined with "+", in the order the paths were given.
        self.data = shards[0] + shards[1] + shards[2] + shards[3]

        # The per-shard containers are no longer needed once joined.
        del shards
        gc.collect()

        self.pad = nn.ReplicationPad2d(4)
        self.crop_size = 84

    def __len__(self):
        """Return the number of samples across all shards."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        Returns a sample from the dataset.

        Args:
            idx (int): Index of the sample.

        Returns:
            dict: A dictionary containing:
                - "seq_canonical" (torch.Tensor): The three canonical frames
                  (t-2, t-1, t0) concatenated along the channel dimension
                  ([9, crop_size, crop_size] assuming 3-channel frames).
                - "seq_random" (torch.Tensor): The three randomized frames,
                  concatenated the same way.
                - "action" (torch.Tensor): float32 tensor built from
                  ``sample["actions"][2]``, i.e. the action paired with the
                  last frame of the sequence.
                - "target_canon" (torch.Tensor): Canonical frames t-1 and t0
                  (the very same crops used in "seq_canonical") followed by
                  the future canonical frame.
        """
        sample = self.data[idx]

        # NOTE(review): every frame draws its own random crop offsets, so the
        # frames inside a stack (and input vs. target) are not cropped at the
        # same location - confirm this is intended.
        canonical = sample["canonical"]
        t_2 = self._prepare_frame(canonical[0])
        t_1 = self._prepare_frame(canonical[1])
        t_0 = self._prepare_frame(canonical[2])
        seq_canonical = torch.cat([t_2, t_1, t_0], dim=0)

        randomized = sample["randomized"]
        seq_random = torch.cat(
            [self._prepare_frame(randomized[i]) for i in range(3)],
            dim=0,
        )

        action_0 = torch.tensor(sample["actions"][2], dtype=torch.float32)

        # t-1 and t0 are reused as-is; only the future frame is a new crop.
        seq_next_canon = torch.cat(
            [t_1, t_0, self._prepare_frame(sample["future_canon"])],
            dim=0,
        )

        return {
            "seq_canonical": seq_canonical,
            "seq_random": seq_random,
            "action": action_0,
            "target_canon": seq_next_canon,
        }

    def _prepare_frame(self, frame):
        """
        Turn one H x W x C numpy frame into a padded, randomly cropped
        C x crop_size x crop_size float tensor.

        Args:
            frame (numpy.ndarray): Single frame in channel-last layout.

        Returns:
            torch.Tensor: Float tensor in channel-first layout.
        """
        channel_first = torch.from_numpy(frame).permute(2, 0, 1).float()
        return self._random_crop(self.pad(channel_first))

    def _random_crop(self, padded):
        """
        Performs a random crop on a padded tensor.

        Args:
            padded (torch.Tensor): Padded tensor of shape [C, H, W].

        Returns:
            torch.Tensor: Crop of shape [C, crop_size, crop_size], taken at a
            uniformly sampled offset.
        """
        _, height, width = padded.shape
        # The row offset is sampled before the column offset.
        top = torch.randint(0, height - self.crop_size + 1, (1,)).item()
        left = torch.randint(0, width - self.crop_size + 1, (1,)).item()

        return padded[:, top:top + self.crop_size, left:left + self.crop_size]

In [2]:
def create_dataloader(file_path1, file_path2, file_path3, file_path4, batch_size=128, shuffle=True, num_workers=4):
    """
    Build the dataset from four .pt shards and wrap it in a DataLoader.

    Args:
        file_path1 (str): Path to the first .pt shard.
        file_path2 (str): Path to the second .pt shard.
        file_path3 (str): Path to the third .pt shard.
        file_path4 (str): Path to the fourth .pt shard.
        batch_size (int): Number of samples per batch.
        shuffle (bool): Whether the loader shuffles the samples.
        num_workers (int): Number of worker processes used for loading.

    Returns:
        DataLoader: Loader over a ``CustomDataset`` with pinned memory enabled.
    """
    loader_options = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "num_workers": num_workers,
        "pin_memory": True,  # always requested, regardless of device
    }
    samples = CustomDataset(file_path1, file_path2, file_path3, file_path4)
    return DataLoader(samples, **loader_options)

In [None]:
import torch
from google.colab import drive
from pathlib import Path
# Make Google Drive visible inside the Colab runtime.
drive.mount('/content/drive')

# Project root on Google Drive.
# Alternative location used previously:
#   "/content/drive/My Drive/Reinforcement Learning/Reinforcement Learning/DRAP"
output_dir = "/content/drive/My Drive/Reinforcement Learning/DRAP"

# Paths of the four walker dataset shards (.pt files), numbered 1 to 4.
file_path1, file_path2, file_path3, file_path4 = (
    f"{output_dir}/dataset/walker/dataset_sequences_{shard}.pt"
    for shard in range(1, 5)
)

# Load every shard and build the training loader.
dataloader = create_dataloader(file_path1, file_path2, file_path3, file_path4, batch_size=128)

In [4]:
# Report the total number of samples held by the dataset behind the loader.
print(len(dataloader.dataset))

25000


In [None]:
# Sanity check: walk over every batch and print the shape of each entry.
for batch in dataloader:
    seq_canonical, seq_random, action, seq_next_canon = (
        batch[key]
        for key in ("seq_canonical", "seq_random", "action", "target_canon")
    )

    print(f" Sequences lenght: {seq_canonical.shape}, {seq_random.shape}. Action shape: {action.shape}. Target canon shape: {seq_next_canon.shape}")


# Auto-Encoder

In [4]:
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

In [5]:
class Encoder(nn.Module):
    """
    Convolutional encoder for image-based observations (same layout as DrQ).

    Four 3x3 convolutions (the first with stride 2) are followed by a linear
    projection, layer normalization and a tanh. The linear layer expects a
    35 x 35 feature map from the last convolution.
    """

    def __init__(self, obs_shape, n_features, device='cuda'):
        """
        Args:
            obs_shape: Observation shape; only its first entry (the number of
                input channels) is used.
            n_features: Size of the output embedding.
            device: Stored on the instance; the module is not moved here.
        """
        super().__init__()
        self.device = device
        self.n_features = n_features
        self.img_channels = obs_shape[0]
        self.n_filters = 32

        width = self.n_filters
        self.conv1 = nn.Conv2d(self.img_channels, width, kernel_size=3, stride=2)
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=1)
        self.conv3 = nn.Conv2d(width, width, kernel_size=3, stride=1)
        self.conv4 = nn.Conv2d(width, width, kernel_size=3, stride=1)

        self.fc = nn.Linear(35 * 35 * width, self.n_features)
        self.norm = nn.LayerNorm(self.n_features)

    def forward(self, obs, detach=False):
        """
        Encode a batch of observations.

        Args:
            obs: Batch of images; it is divided by 255 before the first
                convolution.
            detach: When True, the flattened convolutional features are
                detached, so gradients only reach the linear and norm layers.

        Returns:
            torch.Tensor: tanh-squashed embedding, one row per batch element.
        """
        scaled = obs / 255.0

        # Every intermediate activation is kept as an attribute of the module.
        self.conv1_output = F.relu(self.conv1(scaled))
        self.conv2_output = F.relu(self.conv2(self.conv1_output))
        self.conv3_output = F.relu(self.conv3(self.conv2_output))
        self.conv4_output = F.relu(self.conv4(self.conv3_output))

        batch = self.conv4_output.size(0)
        flat = self.conv4_output.reshape(batch, -1)
        flat = flat.detach() if detach else flat

        self.fc_output = self.fc(flat)
        self.norm_output = self.norm(self.fc_output)

        return torch.tanh(self.norm_output)

In [6]:
class Decoder(nn.Module):
    """
    Decoder that reconstructs images from an embedding.

    Two linear layers expand the embedding into a 42 x 42 feature map, a
    transposed convolution doubles its spatial size, and two convolutions
    produce the output image, squashed by a sigmoid.
    """

    def __init__(self, n_features, img_channels, device='cuda'):
        """
        Args:
            n_features: Size of the input embedding.
            img_channels: Number of channels of the reconstructed image.
            device: Stored on the instance; the module is not moved here.
        """
        super().__init__()
        self.device = device
        self.n_features = n_features
        self.img_channels = img_channels
        self.n_filters = 32

        # Embedding -> flat vector that is later reshaped to a spatial tensor.
        self.fc1 = nn.Linear(n_features, self.n_filters)
        self.fc2 = nn.Linear(self.n_filters, 42 * 42 * self.n_filters)

        # Up-sampling stage.
        self.upconv = nn.ConvTranspose2d(
            self.n_filters, self.n_filters, kernel_size=4, stride=2, padding=1
        )

        # Size-preserving convolutions; the last one maps to image channels.
        self.conv1 = nn.Conv2d(
            self.n_filters, self.n_filters, kernel_size=3, stride=1, padding=1
        )
        self.conv2 = nn.Conv2d(
            self.n_filters, img_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, z):
        """
        Decode a batch of embeddings.

        Args:
            z: Batch of embeddings whose last dimension is ``n_features``.

        Returns:
            torch.Tensor: Reconstructed images with values in [0, 1].
        """
        hidden = F.relu(self.fc1(z))
        hidden = F.relu(self.fc2(hidden))

        feature_map = hidden.view(-1, self.n_filters, 42, 42)
        feature_map = F.relu(self.upconv(feature_map))
        feature_map = F.relu(self.conv1(feature_map))

        return torch.sigmoid(self.conv2(feature_map))

In [7]:
class MLP(nn.Module):
    """
    Two-layer perceptron that merges an embedding with an action.

    It maps the concatenation of embedding and action to a new embedding of
    the same size, used as the prediction of the future embedding.
    """

    def __init__(self, n_features, action_dim):
        """
        Args:
            n_features: Size of the input and output embedding.
            action_dim: Size of the action vector.
        """
        super().__init__()
        self.fc1 = nn.Linear(n_features + action_dim, 1024)
        self.fc2 = nn.Linear(1024, n_features)

    def forward(self, z, action):
        """
        Args:
            z: Embedding batch, last dimension ``n_features``.
            action: Action batch, last dimension ``action_dim``.

        Returns:
            torch.Tensor: Predicted embedding, last dimension ``n_features``
            (no activation on the output layer).
        """
        joint = torch.cat((z, action), dim=-1)
        hidden = F.relu(self.fc1(joint))
        return self.fc2(hidden)

In [8]:
class DRAPModel(nn.Module):
    """
    Domain Randomization Removal Pre-training (DRAP) Model.

    An encoder embeds a stack of randomized observations; a shared decoder
    reconstructs both the current stack (from the embedding) and the next
    stack (from the embedding advanced by an action-conditioned MLP).
    """

    def __init__(self, obs_shape, n_features, action_dim, device = torch.device("cuda")):
        """
        Args:
            obs_shape: Observation shape; its first entry is the number of
                image channels used by both encoder and decoder.
            n_features: Size of the bottleneck embedding.
            action_dim: Size of the action vector.
            device: Device on which the three sub-modules are placed.
        """
        super().__init__()
        self.device = device
        self.encoder = Encoder(obs_shape, n_features, device).to(self.device)
        self.mlp = MLP(n_features, action_dim).to(self.device)
        self.decoder = Decoder(n_features, obs_shape[0], device).to(self.device)

    def forward(self, obs_stack, action):
        """
        Args:
            obs_stack: Stack of 3 randomized observations (shape: [batch_size, 9, H, W]).
            action: Last action t0 (shape: [batch_size, action_dim]).
        Returns:
            recon_stack: Reconstruction of the sequence (t-2, t-1, t0).
            recon_next: Reconstruction of the future sequence (t-1, t0, t+1).
        """
        z = self.encoder(obs_stack)  # Bottleneck embedding

        recon_stack = self.decoder(z)

        # Advance the embedding by one step, conditioned on the action.
        z_future = self.mlp(z, action)

        recon_next = self.decoder(z_future)

        return recon_stack, recon_next

    def _train(self, dataloader, epochs=1, lr=3e-4):
        """
        Train encoder, MLP and decoder jointly with Adam and an MSE loss.

        The loss is the sum of two reconstruction terms: current canonical
        stack and next canonical stack.

        Args:
            dataloader: Yields dicts with the keys "seq_canonical",
                "seq_random", "action" and "target_canon".
            epochs (int): Number of passes over the dataloader.
            lr (float): Adam learning rate.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        criterion = nn.MSELoss()
        n_batches = len(dataloader)

        for epoch in range(epochs):
            self.train()
            train_loss = 0.0

            with tqdm(total=n_batches, desc=f"Epoch {epoch + 1}/{epochs}", unit="batch") as pbar:
                for batch in dataloader:
                    seq_random = batch["seq_random"].to(self.device)
                    action = batch["action"].to(self.device)

                    # The decoder ends in a sigmoid, so its output lies in
                    # [0, 1], and the encoder divides its input by 255. The
                    # targets are rescaled the same way so that both sides of
                    # the MSE share one range.
                    # Assumes frames are stored on a 0-255 scale - TODO confirm
                    # against the dataset files.
                    seq_canonical = batch["seq_canonical"].to(self.device) / 255.0
                    seq_next_canon = batch["target_canon"].to(self.device) / 255.0

                    optimizer.zero_grad()
                    # Call the module (not .forward) so registered hooks run.
                    recon_stack, recon_next = self(seq_random, action)
                    loss = criterion(recon_stack, seq_canonical) + criterion(recon_next, seq_next_canon)
                    loss.backward()
                    optimizer.step()

                    batch_loss = loss.item()
                    train_loss += batch_loss

                    pbar.set_postfix(loss=batch_loss)
                    pbar.update(1)

            # Each accumulated value is already a per-batch mean, so the epoch
            # average is taken over batches, not over samples.
            avg_train_loss = train_loss / max(n_batches, 1)
            print(f"Epoch {epoch + 1}/{epochs}: Avg Training Loss: {avg_train_loss:.4f}")

    def save_encoder_weights(self, file_path):
        """
        Save the encoder parameters (its state dict) to a .pt file.

        Args:
            file_path: Path to save the .pt file.
        """
        encoder_weights = self.encoder.state_dict()

        torch.save(encoder_weights, file_path)
        print(f'Encoder parameters saved in {file_path}')


In [None]:
# Build the model: 9 input channels, 50-dimensional embedding, 6-dimensional action.
# NOTE(review): the observation shape is given as (9, 84, 88); only its first
# entry is read by the model, but 88 looks like a typo for 84 - confirm.
model = DRAPModel(obs_shape=(9, 84, 88), n_features=50, action_dim=6, device='cuda')

# One epoch of pre-training on the loaded dataset.
model._train(dataloader, epochs=1)

# Persist only the encoder weights.
# For the finger task the target file was "/encoder_weights_finger.pt".
model.save_encoder_weights(f"{output_dir}/encoder_weights_walker.pt")