In [1]:
import torch
from torchvision import models
from torch import nn


In [10]:
import torch
import torch.nn as nn
import torchvision.models as models

def adapt_state_dict(state_dict):
    """
    Return a copy of ``state_dict`` whose keys have lost their leading component.

    Everything up to and including the first ``'.'`` of each key is dropped
    (``'0.conv1.weight'`` becomes ``'conv1.weight'``); the values are carried
    over untouched. A key that contains no ``'.'`` maps to the empty string,
    and when two keys collapse to the same name the later entry wins.
    """
    # partition('.') yields (head, sep, tail); the tail is the key without its prefix.
    return {name.partition('.')[2]: tensor for name, tensor in state_dict.items()}

class LinearClassifier(nn.Module):
    """Linear classifier on top of an RGB encoder and a tactile encoder.

    Two ResNet trunks (3-channel RGB, 6-channel tactile) each produce one
    pooled feature vector per sample; the two vectors are concatenated and fed
    to a single ``nn.Linear`` layer that outputs ``num_classes`` logits.

    Args:
        num_classes: Number of output logits of the linear layer.
        checkpoint_path: File passed to ``torch.load`` when ``pretrained`` is
            true. It must contain the entries ``'state_dict_vis'`` and
            ``'state_dict_tac'``.
        nn_model: Backbone name, one of the keys of ``_FEATURE_DIMS``.
        pretrained: When true, encoder weights are read from the checkpoint
            and all encoder parameters are frozen (``requires_grad = False``).

    Raises:
        ValueError: If ``nn_model`` is not a supported backbone.
    """

    # Width of the pooled feature vector produced by each supported backbone.
    _FEATURE_DIMS = {'resnet18': 512, 'resnet50': 2048}

    def __init__(self, num_classes, checkpoint_path, nn_model='resnet18', pretrained=True):
        super(LinearClassifier, self).__init__()
        self.nn_model = nn_model
        self.rgb_encoder = self.create_resnet_encoder(3)
        self.tactile_encoder = self.create_resnet_encoder(6)

        if pretrained:
            # map_location='cpu' lets a checkpoint saved on a GPU be opened on
            # a CPU-only machine; load_state_dict copies the values into the
            # encoder parameters wherever those live.
            # NOTE: torch.load unpickles the file - only open trusted checkpoints.
            checkpoint = torch.load(checkpoint_path, map_location='cpu')

            self._load_encoder_weights(self.rgb_encoder, checkpoint['state_dict_vis'], 'rgb')
            self._load_encoder_weights(self.tactile_encoder, checkpoint['state_dict_tac'], 'tactile')

            # Freeze the weights of the encoders
            # NOTE(review): this freezes parameters only. While the module is
            # in train mode, BatchNorm layers inside the encoders still update
            # their running statistics - confirm that is intended.
            for encoder in (self.rgb_encoder, self.tactile_encoder):
                for param in encoder.parameters():
                    param.requires_grad = False

        # Each encoder contributes one feature vector; size the head for the
        # chosen backbone instead of assuming ResNet-18's 512.
        feature_dim = self._FEATURE_DIMS[nn_model]
        self.linear_layer = nn.Linear(feature_dim * 2, num_classes)

    @staticmethod
    def _load_encoder_weights(encoder, state_dict, name):
        """Load checkpoint weights into ``encoder`` and warn about unmatched keys.

        Keys are first rewritten with ``adapt_state_dict``. Loading stays
        non-strict, but encoder entries the checkpoint did not supply are
        reported, because those entries keep their random initialisation.
        """
        import warnings

        result = encoder.load_state_dict(adapt_state_dict(state_dict), strict=False)
        if result.missing_keys:
            warnings.warn(
                f"{name} encoder: {len(result.missing_keys)} parameter/buffer entries were not "
                f"found in the checkpoint and keep their random initialisation "
                f"(first missing: {result.missing_keys[0]!r}; "
                f"{len(result.unexpected_keys)} checkpoint keys were unused)."
            )

    def create_resnet_encoder(self, n_channels):
        """Create a ResNet encoder based on the specified model type.

        Args:
            n_channels: Number of input channels. For any value other than 3
                the first convolution is replaced by a fresh ``nn.Conv2d``
                taking that many channels.

        Returns:
            An ``nn.Sequential`` of the ResNet children without its last two
            (avgpool and fc), followed by ``AdaptiveAvgPool2d((1, 1))`` and
            ``Flatten``.

        Raises:
            ValueError: If ``self.nn_model`` is not a supported backbone.
        """
        if self.nn_model not in self._FEATURE_DIMS:
            raise ValueError(
                f"Unsupported nn_model {self.nn_model!r}; expected one of {sorted(self._FEATURE_DIMS)}"
            )
        # Calling the constructor without arguments gives randomly initialised
        # weights and avoids the deprecated ``pretrained=`` keyword.
        resnet = getattr(models, self.nn_model)()
        if n_channels != 3:
            resnet.conv1 = nn.Conv2d(n_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        features = list(resnet.children())[:-2]  # Exclude the avgpool and fc layers
        features.append(nn.AdaptiveAvgPool2d((1, 1)))
        features.append(nn.Flatten())
        return nn.Sequential(*features)

    def forward(self, rgb_input, tactile_input):
        """Return class logits for a batch of paired RGB and tactile inputs."""
        rgb_features = self.rgb_encoder(rgb_input)
        tactile_features = self.tactile_encoder(tactile_input)

        # Concatenate the features from both encoders
        combined_features = torch.cat((rgb_features, tactile_features), dim=1)

        return self.linear_layer(combined_features)


In [11]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint_path = 'runs/Oct02_16-23-44_cpsadmin-Z790-AORUS-ELITE-AX/model_6_best_object_wise.pth'

linear_classifier = LinearClassifier(num_classes=10, checkpoint_path=checkpoint_path, nn_model='resnet18', pretrained=True)
# The criterion and the evaluation batches are sent to `device` in later
# cells, so the model has to live there too; otherwise the forward pass fails
# with a device mismatch on CUDA machines.
linear_classifier = linear_classifier.to(device)



In [18]:
from data_aug.contrastive_learning_dataset import ContrastiveLearningDataset
# Loader configuration shared by the train and test splits.
batch_size = 256
num_workers = 16

dataset = ContrastiveLearningDataset(root_folder='calandra_objects_split_object_wise')
# NOTE(review): the second argument is presumably the number of views per
# sample - confirm against ContrastiveLearningDataset.get_dataset.
train_dataset = dataset.get_dataset('calandra_label_train', 2)
test_dataset = dataset.get_dataset('calandra_label_test', 2)

# Everything except `shuffle` is identical for both loaders.
_loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=num_workers,
    drop_last=False,
    pin_memory=True,
)
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

In [20]:
# # plot a few training image triplets (cell currently disabled)
# import matplotlib.pyplot as plt
# import numpy as np
# import torchvision
# 
# def imshow(img):
#     img = img / 2 + 0.5     # unnormalize
#     npimg = img.cpu().numpy()
#     plt.imshow(np.transpose(npimg, (1, 2, 0)))
#     plt.show()
#     
# # get some random training images
# rgb_image_q, rgb_image_k, stacked_gelsight_images_q, stacked_gelsight_images_k, label = next(iter(train_loader))
# 
# # unstack the gelsight images
# gelsightA_image_q, gelsightB_image_q = torch.chunk(stacked_gelsight_images_q, 2, dim=1)
# 
# # show image in a grid
# imshow(torchvision.utils.make_grid(rgb_image_q))
# imshow(torchvision.utils.make_grid(gelsightA_image_q))
# imshow(torchvision.utils.make_grid(gelsightB_image_q))
# 
# # show the label
# print(label)


In [21]:
# Cross-entropy over the classifier logits, placed on the compute device.
criterion = torch.nn.CrossEntropyLoss().to(device)

# Adam over every parameter of the classifier.
optimizer = torch.optim.Adam(
    params=linear_classifier.parameters(),
    lr=3e-4,
    weight_decay=0.0008,
)

In [None]:
from utils import accuracy

epochs = 20

# Make sure the model sits on the same device as the batches moved below.
# Module.to() moves parameters in place, so the optimizer created earlier
# keeps pointing at the right tensors; the call is a no-op if already there.
linear_classifier.to(device)

for epoch in range(epochs):
    # ---- training pass -------------------------------------------------
    linear_classifier.train()
    top1_train_accuracy = 0
    for counter, data in enumerate(train_loader):
        # Only the first ("q") RGB / tactile entries and the label are used.
        rgb_image_q, _, stacked_gelsight_images_q, _, label = data

        # Training batches must be on `device` just like the test batches.
        rgb_image_q = rgb_image_q.to(device)
        stacked_gelsight_images_q = stacked_gelsight_images_q.to(device)
        label = label.to(device)

        logits = linear_classifier(rgb_image_q, stacked_gelsight_images_q)
        loss = criterion(logits, label)
        top1 = accuracy(logits, label, topk=(1,))
        top1_train_accuracy += top1[0]

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    top1_train_accuracy /= (counter + 1)

    # ---- evaluation pass -----------------------------------------------
    # eval() makes BatchNorm use its running statistics and stops test data
    # from altering them; no_grad() skips building autograd graphs.
    linear_classifier.eval()
    top1_accuracy = 0
    top5_accuracy = 0
    with torch.no_grad():
        for counter, data in enumerate(test_loader):
            rgb_image_q, _, stacked_gelsight_images_q, _, label = data

            rgb_image_q = rgb_image_q.to(device)
            stacked_gelsight_images_q = stacked_gelsight_images_q.to(device)
            label = label.to(device)

            logits = linear_classifier(rgb_image_q, stacked_gelsight_images_q)

            top1, top5 = accuracy(logits, label, topk=(1,5))
            top1_accuracy += top1[0]
            top5_accuracy += top5[0]

    top1_accuracy /= (counter + 1)
    top5_accuracy /= (counter + 1)
    print(f"Epoch {epoch}:\tTrain Accuracy: {top1_train_accuracy.item():.2f}\tTest Accuracy: {top1_accuracy.item():.2f}\tTest Top-5 Accuracy: {top5_accuracy.item():.2f}")
