In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import pickle
import torch.nn as nn
from sklearn.metrics import roc_auc_score
from timm.models.vision_transformer import vit_base_patch16_224
import torch.optim as optim
from PIL import Image
from torchvision.transforms import functional as F
from torch.utils.data import DataLoader, random_split, Dataset
from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedKFold
import wandb
import timm


In [None]:
# Run on the first CUDA GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [None]:
# Three independent placeholder tensors of shape (2, 100, 100), each filled with
# uniform samples from [0, 1). They are drawn in the order x, x_i, x_j.
x, x_i, x_j = (torch.rand(2, 100, 100) for _ in range(3))

In [None]:
# Augmentation pipeline handed to the dataset below. Each call converts the sample
# to a PIL image, mirrors it horizontally with probability 0.5, and converts the
# result back to a tensor.
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
    ]
)

#create a custom dataset and add one augmentation
class CustomDataset(Dataset):
    """Map-style dataset that returns every sample together with two views of it.

    Args:
        matrix: Indexable collection of samples; it must support ``len()`` and
            integer indexing (for example a tensor whose first axis enumerates
            the samples).
        transform: Optional callable applied to the sample twice, once per view.
            With a random augmentation the two views generally differ. When it is
            ``None`` both views are the untransformed sample.
    """

    def __init__(self, matrix, transform=None):
        self.matrix = matrix
        self.transform = transform

    def __len__(self):
        """Return the number of samples in ``matrix``."""
        return len(self.matrix)

    def __getitem__(self, index):
        """Return ``(sample, view_i, view_j)`` for the sample at ``index``."""
        matrix = self.matrix[index]
        if self.transform is None:
            # No augmentation configured: fall back to the sample itself so the
            # method always returns a 3-tuple instead of hitting unbound locals.
            return matrix, matrix, matrix

        # Two separate calls so a stochastic transform yields two independent views.
        x_i = self.transform(matrix)
        x_j = self.transform(matrix)
        return matrix, x_i, x_j

In [None]:

class MultiHeadSelfAttention(nn.Module):
    """Scaled dot-product self-attention over several heads, without masking.

    A single bias-free linear layer produces queries, keys and values at once.
    Its output is first split across heads, and each head's slice of width
    ``3 * head_dim`` is then cut into query, key and value parts.

    Args:
        d_model: Size of the last dimension of the input and of the output.
        num_heads: Number of attention heads; must divide ``d_model`` exactly.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        head_dim, remainder = divmod(d_model, num_heads)
        self.num_heads = num_heads
        self.d_model = d_model
        self.head_dim = head_dim
        assert remainder == 0, "d_model must be divisible by num_heads"
        # Layers are created in this order: fused q/k/v projection, then output projection.
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.fc_out = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Attend over the sequence axis of ``x`` (batch, seq_length, d_model).

        Returns a tensor with the same three-dimensional shape as ``x``.
        """
        n_batch, n_tokens, _ = x.shape
        n_heads, width = self.num_heads, self.head_dim

        # (batch, seq, heads, 3*head_dim) -> (batch*heads, seq, 3*head_dim)
        fused = self.qkv(x).view(n_batch, n_tokens, n_heads, 3 * width)
        fused = fused.transpose(1, 2).reshape(n_batch * n_heads, n_tokens, 3 * width)
        q, k, v = torch.split(fused, width, dim=-1)

        # Pairwise token scores, scaled by sqrt(head_dim) and normalised per query.
        scores = torch.matmul(q, k.transpose(-2, -1)) / width ** 0.5
        weights = scores.softmax(dim=-1)

        # Weighted sum of values, then stitch the heads back into d_model features.
        context = torch.matmul(weights, v).view(n_batch, n_heads, n_tokens, width)
        merged = context.transpose(1, 2).reshape(n_batch, n_tokens, self.d_model)
        return self.fc_out(merged)

In [None]:
# Encoder: one linear layer applied to the last dimension of its input.
class Network(nn.Module):
    """Linear encoder mapping ``input_size`` features to ``output_size`` features.

    Args:
        input_size: Size of the last dimension of the input (default 100).
        output_size: Size of the last dimension of the output (default 2).
    """

    def __init__(self, input_size=100, output_size=2):
        super().__init__()
        # NOTE: an attention module used to be constructed here into a local
        # variable and thrown away; it was never registered on the model nor
        # called in forward(), so it has been dropped. Wire one in explicitly
        # (as an attribute, and in forward) if attention is actually wanted.
        self.fc1 = nn.Linear(input_size, output_size)

    def forward(self, x):
        """Project the last dimension of ``x`` from ``input_size`` to ``output_size``."""
        return self.fc1(x)
    

In [None]:
class Reconstructive_net(nn.Module):
    """Linear decoder mapping ``input_size`` features to ``output_size`` features.

    Args:
        input_size: Size of the last dimension of the input (default 2).
        output_size: Size of the last dimension of the output (default 100).
    """

    def __init__(self, input_size=2, output_size=100):
        super().__init__()
        # Honour the constructor arguments; the defaults give Linear(2, 100).
        self.fc1 = nn.Linear(input_size, output_size)

    def forward(self, x):
        """Project the last dimension of ``x`` from ``input_size`` to ``output_size``."""
        return self.fc1(x)

In [None]:
class Contrastive_net(nn.Module):
    """Projection head: flatten each sample, project linearly, then layer-normalise.

    Args:
        input_size: Number of features per position in the incoming embedding
            (default 2).
        seq_length: Number of positions per sample (default 100). Together with
            ``input_size`` it fixes the flattened width ``seq_length * input_size``;
            the defaults give 200.
        projection_size: Width of the output vector (default 8).
    """

    def __init__(self, input_size=2, seq_length=100, projection_size=8):
        super().__init__()
        # Attribute names are kept as-is so existing state dicts still load.
        self.fc1 = nn.Flatten()
        self.fc2 = nn.Linear(seq_length * input_size, projection_size)
        self.norm1 = nn.LayerNorm(projection_size)

    def forward(self, x):
        """Map ``x`` of shape (batch, ...) to (batch, projection_size).

        All dimensions after the first are flattened, so their product must
        equal ``seq_length * input_size``.
        """
        flat = self.fc1(x)
        return self.norm1(self.fc2(flat))

In [None]:
# Encoder network, built with its default constructor arguments.
main_model = Network()

In [None]:
# Projection head for the contrastive branch, built with its default arguments.
contrastive_model = Contrastive_net()

In [None]:
# Decoder for the reconstruction branch, built with its default arguments.
reconstructive_model = Reconstructive_net()

In [None]:
# Move the three networks onto the selected device, keeping the same names bound.
main_model, contrastive_model, reconstructive_model = (
    net.to(device)
    for net in (main_model, contrastive_model, reconstructive_model)
)


In [None]:
# embeding_i=main_model(x_i)
# embeding_j=main_model(x_j)

In [None]:
# x_recon_i=reconstructive_model(embeding_i)
# x_recon_j=reconstructive_model(embeding_j)

In [None]:
# x_contrast_i=contrastive_model(embeding_i)
# x_contrast_j=contrastive_model(embeding_j)

In [None]:
class SimCLR_Loss(nn.Module):
    """Contrastive loss over two batches of paired embeddings (SimCLR-style).

    Row ``k`` of ``z_i`` and row ``k`` of ``z_j`` form a positive pair. All
    ``2 * batch`` embeddings are compared with temperature-scaled cosine
    similarity. For every embedding a cross-entropy is computed in which the
    positive partner is the target class and every other embedding except the
    embedding itself is a negative. The result is the mean over all
    ``2 * batch`` embeddings.

    Args:
        batch_size: Batch size the negative mask is precomputed for. Batches of
            a different size are still accepted; their mask is built on demand.
        temperature: Divisor applied to the cosine similarities.
    """

    def __init__(self, batch_size, temperature):
        super().__init__()
        self.batch_size = batch_size
        self.temperature = temperature

        self.mask = self.mask_correlated_samples(batch_size)
        self.criterion = nn.CrossEntropyLoss(reduction="sum")
        self.similarity_f = nn.CosineSimilarity(dim=2)

    def mask_correlated_samples(self, batch_size):
        """Return a ``(2B, 2B)`` bool mask that is True exactly at negative pairs.

        The diagonal (self-similarity) and the two off-diagonals at offset
        ``+/- batch_size`` (the positive pairs) are False.
        """
        n_views = 2 * batch_size
        mask = torch.ones((n_views, n_views), dtype=torch.bool)
        mask.fill_diagonal_(False)
        first = torch.arange(batch_size)
        second = first + batch_size
        mask[first, second] = False
        mask[second, first] = False
        return mask

    def forward(self, z_i, z_j):
        """Compute the loss for two aligned embedding batches.

        Args:
            z_i: Tensor of shape (batch, features), first view.
            z_j: Tensor of the same shape as ``z_i``, second view.

        Returns:
            Scalar tensor: summed cross-entropy divided by ``2 * batch``.

        Raises:
            ValueError: If the two inputs differ in shape or the batch is empty.
        """
        if z_i.shape != z_j.shape:
            raise ValueError(
                f"z_i and z_j must have the same shape, got {tuple(z_i.shape)} and {tuple(z_j.shape)}"
            )
        # Use the actual batch size so a smaller final batch does not break the loss.
        batch_size = z_i.shape[0]
        if batch_size == 0:
            raise ValueError("SimCLR_Loss requires a non-empty batch")
        n_views = 2 * batch_size

        z = torch.cat((z_i, z_j), dim=0)
        # Broadcasting (2B, 1, F) against (1, 2B, F) yields the full (2B, 2B) similarity matrix.
        sim = self.similarity_f(z.unsqueeze(1), z.unsqueeze(0)) / self.temperature

        # Positive pairs sit on the diagonals offset by +/- batch_size.
        sim_i_j = torch.diag(sim, batch_size)
        sim_j_i = torch.diag(sim, -batch_size)
        positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(n_views, 1)

        if batch_size == self.batch_size:
            mask = self.mask
        else:
            mask = self.mask_correlated_samples(batch_size)
        # Every row keeps exactly n_views - 2 negatives; giving the width
        # explicitly also covers batch_size == 1, where there are none.
        negative_samples = sim[mask.to(sim.device)].reshape(n_views, n_views - 2)

        # The positive logit is placed in column 0, so the target class is always 0.
        labels = torch.zeros(n_views, dtype=torch.long, device=sim.device)
        logits = torch.cat((positive_samples, negative_samples), dim=1)
        return self.criterion(logits, labels) / n_views


In [None]:
# Contrastive criterion; the batch size matches the DataLoader's batch_size below.
contrastive_loss = SimCLR_Loss(batch_size=32, temperature=0.5)

In [None]:
# loss_contrastive=contrastive_loss(x_contrast_i, x_contrast_j)

In [None]:
# Synthetic stand-in data: 50 samples of shape (100, 100), uniform in [0, 1).
data = torch.rand((50, 100, 100))

In [None]:
# Wrap the synthetic data so every item comes with two augmented views.
dataset = CustomDataset(matrix=data, transform=transform)

In [None]:
# Shuffled batches of 32; the incomplete final batch is dropped.
dataloader = DataLoader(
    dataset=dataset,
    batch_size=32,
    shuffle=True,
    drop_last=True,
)

In [None]:
# Optimise the encoder together with both heads. The training loss flows through
# the contrastive and reconstruction heads as well; passing only the encoder's
# parameters left those heads at their random initialisation and let their
# gradients accumulate unzeroed.
optimizer = optim.Adam(
    [
        *main_model.parameters(),
        *contrastive_model.parameters(),
        *reconstructive_model.parameters(),
    ],
    lr=0.001,
)

In [None]:
# Reconstruction criterion: mean squared error averaged over all elements.
criterion = nn.MSELoss(reduction="mean")

In [None]:
# Joint training for 10 epochs: each batch provides the raw samples plus two
# augmented views. Both views are encoded; the embeddings feed a reconstruction
# head (compared against the raw samples) and a contrastive head (compared with
# each other). The two losses are summed and back-propagated together.
for epoch in range(10):
    for x, x_i, x_j in dataloader:
        optimizer.zero_grad()

        x = x.to(device)
        # Remove only dim 1 of the views, and only if it has size 1 (presumably
        # the channel axis added by the dataset transform -- confirm). A bare
        # squeeze() would also drop the batch axis whenever a batch holds one sample.
        x_i = x_i.to(device).squeeze(1)
        x_j = x_j.to(device).squeeze(1)

        # Shared encoder applied to both views.
        embeding_i = main_model(x_i)
        embeding_j = main_model(x_j)

        # Reconstruction branch.
        x_recon_i = reconstructive_model(embeding_i)
        x_recon_j = reconstructive_model(embeding_j)

        # Contrastive branch.
        x_contrast_i = contrastive_model(embeding_i)
        x_contrast_j = contrastive_model(embeding_j)

        print("x", x.shape)
        loss_contrastive = contrastive_loss(x_contrast_i, x_contrast_j)
        loss_reconstructive = criterion(x_recon_i, x) + criterion(x_recon_j, x)
        loss = loss_contrastive + loss_reconstructive
        loss.backward()

        # Log plain Python floats rather than tensors carrying autograd metadata.
        print(loss_contrastive.item(), loss_reconstructive.item())
        print("loss:", loss.item())
        optimizer.step()



In [None]:
# Scratch check: stacking two length-30 vectors and reshaping gives an (N, 1) column.
sim_i_j = torch.rand(30)
sim_j_i = torch.rand(30)
N = 60
torch.cat([sim_i_j, sim_j_i]).reshape(N, 1)