---
## 1. Library Imports
In this section, we import the necessary libraries and packages that will be used throughout the notebook. This includes the PyTorch library for deep learning, torchvision for computer vision tasks, and other utilities.

---

In [1]:
import torch
from torchvision import models
from torch import nn
from data_aug.contrastive_learning_dataset import ContrastiveLearningDataset
from tqdm import tqdm
from utils import accuracy
import os
import wandb

## 2. Model Definition and Helper Functions

Here, we define helper functions and the main model architecture. The `adapt_state_dict` function adapts the state dictionary for loading pretrained models. The `LinearClassifier` class is the main model that uses ResNet-based encoders for both RGB and tactile inputs. Pretrained weights can be loaded into this model for transfer learning.

---

In [2]:
def adapt_state_dict(state_dict):
    """
    Adapts the state dictionary's key names to match the expected keys of the ResNet model.
    """
    adapted_state_dict = {}
    for k, v in state_dict.items():
        # Remove the prefixed numbers from the key names
        new_key = '.'.join(k.split('.')[1:])
        adapted_state_dict[new_key] = v
    return adapted_state_dict

class LinearClassifier(nn.Module):
    def __init__(self, num_classes, checkpoint_path, nn_model='resnet18', pretrained=True):
        super(LinearClassifier, self).__init__()
        self.nn_model = nn_model
        self.rgb_encoder = self.create_resnet_encoder(3)
        self.tactile_encoder = self.create_resnet_encoder(6)
        
        if pretrained:
            # Load the checkpoint
            checkpoint = torch.load(checkpoint_path)
            
            # Adapt the state dictionary key names
            adapted_rgb_state_dict = adapt_state_dict(checkpoint['state_dict_vis'])
            adapted_tactile_state_dict = adapt_state_dict(checkpoint['state_dict_tac'])
            
            # Load the state dict for the visual and tactile encoders
            self.rgb_encoder.load_state_dict(adapted_rgb_state_dict, strict=False)
            self.tactile_encoder.load_state_dict(adapted_tactile_state_dict, strict=False)
            
            # Freeze the weights of the encoders
            for param in self.rgb_encoder.parameters():
                param.requires_grad = False
            for param in self.tactile_encoder.parameters():
                param.requires_grad = False
        # Assuming the output features of both encoders are of size 512 (e.g., for ResNet-18)
        # Adjust this if the size is different
        self.linear_layer = nn.Linear(512 * 2, num_classes)
    
    def create_resnet_encoder(self, n_channels):
        """Create a ResNet encoder based on the specified model type."""
        if self.nn_model == 'resnet18':
            resnet = models.resnet18(pretrained=False)
        elif self.nn_model == 'resnet50':
            resnet = models.resnet50(pretrained=False)
        if n_channels != 3:
            resnet.conv1 = nn.Conv2d(n_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        features = list(resnet.children())[:-2]  # Exclude the avgpool and fc layers
        features.append(nn.AdaptiveAvgPool2d((1, 1)))
        features.append(nn.Flatten())
        return nn.Sequential(*features)

    def forward(self, rgb_input, tactile_input):
        rgb_features = self.rgb_encoder(rgb_input)
        tactile_features = self.tactile_encoder(tactile_input)
        
        # Concatenate the features from both encoders
        combined_features = torch.cat((rgb_features, tactile_features), dim=1)
        
        return self.linear_layer(combined_features)


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint_path = 'runs/Oct04_09-28-47_cpsadmin-Z790-AORUS-ELITE-AX/model_700_best_object_wise.pth'
print(f"Using device: {device}")
linear_classifier = LinearClassifier(num_classes=10, checkpoint_path=checkpoint_path, nn_model='resnet18', pretrained=True)
linear_classifier = linear_classifier.to(device)

Using device: cuda




## 3. Dataset Preparation
This section prepares the datasets for training and evaluation. Data augmentation techniques might be applied to increase the variability in the training data. We also define the data loaders, which will provide batches of data during training and evaluation.

---

In [4]:
batch_size = 512
num_workers = 8
use_wandb = True
dataset = ContrastiveLearningDataset(root_folder='calandra_objects_split_object_wise')
train_dataset = dataset.get_dataset('calandra_label_train', 2)
test_dataset = dataset.get_dataset('calandra_label_test', 2,)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                           num_workers=num_workers, drop_last=False, pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False,  
                                            num_workers=num_workers, drop_last=False, pin_memory=True)

## 4. Optimizer and Loss Function
Before training, we need to define how the model will be optimized and what loss function will measure the model's performance. Here, we initialize the Adam optimizer and the cross-entropy loss function.

---

In [5]:
optimizer = torch.optim.Adam(linear_classifier.parameters(), lr=3e-4, weight_decay=0.0008)
criterion = torch.nn.CrossEntropyLoss().to(device)

## 5. Training and Evaluation Loop
The core of the notebook. In this section, we train the model on the training data and evaluate its performance on the test data. The model's weights are updated in each epoch based on the optimizer and loss function. After training, we evaluate the model's performance on unseen data to gauge its generalization capability.

---

In [6]:
epochs = 20
if use_wandb:
    # init wandb
    wandb.init(project="calandra_object_wise_linear_classifier", name="resnet18_pretrained")
    if not os.path.exists(f"runs/{wandb.run.name}"):
        os.makedirs(f"runs/{wandb.run.name}")
    subfolder = wandb.run.name
else:
    subfolder = "test"

[34m[1mwandb[0m: Currently logged in as: [33mligerfotis[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [8]:
best_train_accuracy = 0
for epoch in range(epochs):
    top1_train_accuracy = 0
    epoch_loss = 0
    linear_classifier.train()
    pbar = tqdm(train_loader)
    for counter, data in enumerate(pbar):
        rgb_image_q, _, stacked_gelsight_images_q, _, label = data
        
        rgb_image_q = rgb_image_q.to(device)
        stacked_gelsight_images_q = stacked_gelsight_images_q.to(device)
        label = label.to(device)
        
        logits = linear_classifier(rgb_image_q, stacked_gelsight_images_q)
        loss = criterion(logits, label)
        epoch_loss += loss.item()
        top1 = accuracy(logits, label, topk=(1,))
        top1_train_accuracy += top1[0]
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # update the progress bar message
        pbar.set_description(f"Epoch {epoch}: Loss: {epoch_loss / (counter + 1):.2f}\tTrain Accuracy: {top1_train_accuracy.item() / (counter + 1):.2f}")
    epoch_loss /= (len(train_loader))
    top1_train_accuracy /= (len(train_loader))
    # save the model
    if top1_train_accuracy > best_train_accuracy:
        torch.save(linear_classifier.state_dict(), f"runs/{subfolder}/linear_classifier_{epoch}_best_object_wise.pth")
        best_train_accuracy = top1_train_accuracy
        
    top1_accuracy = 0
    top5_accuracy = 0
    linear_classifier.eval()
    pbar = tqdm(test_loader)
    with torch.no_grad():
        for counter, data in enumerate(pbar):
            rgb_image_q, _, stacked_gelsight_images_q, _, label = data
            
            rgb_image_q = rgb_image_q.to(device)
            stacked_gelsight_images_q = stacked_gelsight_images_q.to(device)
            label = label.to(device)
            
            logits = linear_classifier(rgb_image_q, stacked_gelsight_images_q)
            
            top1, top5 = accuracy(logits, label, topk=(1,5))
            top1_accuracy += top1[0]
            top5_accuracy += top5[0]
            
            # update the progress bar message
            pbar.set_description(f"Epoch {epoch}:\tTrain Accuracy: {top1_train_accuracy.item():.2f}\tTest Accuracy: {top1_accuracy.item()/ (counter + 1):.2f}\tTest Top-5 Accuracy: {top5_accuracy.item()/ (counter + 1):.2f}")
    
    top1_accuracy /= (len(test_loader))
    top5_accuracy /= (len(test_loader))
    if use_wandb:
        wandb.log({"train_accuracy": top1_train_accuracy,
                   "test_accuracy": top1_accuracy,
                   "test_top5_accuracy": top5_accuracy,
                   "epoch_loss": epoch_loss})
    print(f"Epoch {epoch}:\tEpoch Loss: {epoch_loss:.2f}\tTrain Accuracy: {top1_train_accuracy.item():.2f}\tTest Accuracy: {top1_accuracy.item():.2f}\tTest Top-5 Accuracy: {top5_accuracy.item():.2f}")

Epoch 0: Loss: 0.64	Train Accuracy: 65.95: 100%|██████████| 37/37 [03:05<00:00,  5.02s/it]
Epoch 0:	Train Accuracy: 65.95	Test Accuracy: 55.94	Test Top-5 Accuracy: 100.00: 100%|██████████| 2/2 [00:37<00:00, 18.52s/it]


Epoch 0:	Epoch Loss: 0.64	Train Accuracy: 65.95	Test Accuracy: 55.94	Test Top-5 Accuracy: 100.00


Epoch 1: Loss: 0.64	Train Accuracy: 65.99: 100%|██████████| 37/37 [03:05<00:00,  5.01s/it]
Epoch 1:	Train Accuracy: 65.99	Test Accuracy: 55.94	Test Top-5 Accuracy: 100.00: 100%|██████████| 2/2 [00:36<00:00, 18.29s/it]


Epoch 1:	Epoch Loss: 0.64	Train Accuracy: 65.99	Test Accuracy: 55.94	Test Top-5 Accuracy: 100.00


Epoch 2: Loss: 0.64	Train Accuracy: 66.18:  86%|████████▋ | 32/37 [03:05<00:28,  5.79s/it]


KeyboardInterrupt: 