In [66]:
# !pip install wilds

In [67]:
import numpy as np
from wilds import get_dataset
from wilds.common.data_loaders import get_train_loader
import torchvision.transforms as transforms
import torchvision.models as models
from wilds.common.data_loaders import get_eval_loader
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
import torch.nn.functional as F
from tqdm import tqdm

In [68]:
# Load the full dataset, and download it if necessary
dataset = get_dataset(dataset="camelyon17", download=False)

In [69]:
BATCH_SIZE = 32
FRACTION = 0.33

# Get the training set
train_data = dataset.get_subset(
    "train",
    frac = FRACTION,
    transform=transforms.Compose(
        [
         transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])]
    ),
)

print(len(train_data)) #302436 initially

"""
# (Optional) Load unlabeled data
dataset = get_dataset(dataset="camelyon17", download=True, unlabeled=True)
unlabeled_data = dataset.get_subset(
    "test_unlabeled",
    transform=transforms.Compose(
        [transforms.Resize((448, 448)), transforms.ToTensor()]
    ),
)
unlabeled_loader = get_train_loader("standard", unlabeled_data, batch_size=16)
"""
"""
# Train loop
for labeled_batch, unlabeled_batch in zip(train_loader, unlabeled_loader):
    x, y, metadata = labeled_batch
    unlabeled_x, unlabeled_metadata = unlabeled_batch
    ...
"""

99804


'\n# Train loop\nfor labeled_batch, unlabeled_batch in zip(train_loader, unlabeled_loader):\n    x, y, metadata = labeled_batch\n    unlabeled_x, unlabeled_metadata = unlabeled_batch\n    ...\n'

In [70]:
# Get the id_val set
id_val_data = dataset.get_subset(
    "id_val",
    frac = FRACTION,
    transform=transforms.Compose(
        [
         transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])]
    ),
)

print(len(id_val_data))

11075


In [71]:
# Get the val set
val_data = dataset.get_subset(
    "val",
    frac = FRACTION,
    transform=transforms.Compose(
        [
         transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])]
    ),
)

print(len(val_data))

11518


In [72]:
# Get the test set
test_data = dataset.get_subset(
    "test",
    frac = FRACTION,
    transform=transforms.Compose(
        [
         transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])]
    ),
)

print(len(test_data))


28068


In [73]:
print(test_data[0])

(tensor([[[ 1.1358,  0.9474,  0.9132,  ..., -0.3369,  0.7762,  0.0741],
         [ 1.2043,  0.9646,  0.8961,  ...,  1.5982,  1.6324,  0.7933],
         [ 1.1872,  0.9303,  0.9303,  ...,  1.4783,  1.3755,  1.1358],
         ...,
         [-0.3541,  0.1426,  1.1187,  ...,  1.4269,  0.3994, -0.3541],
         [-0.7479, -0.5082,  0.4851,  ...,  1.1700, -0.3883, -0.8849],
         [-0.8849, -0.5938,  0.4337,  ...,  0.8961, -0.5596, -0.4397]],

        [[ 0.3277,  0.1001,  0.0651,  ..., -0.5301, -0.0924, -0.3200],
         [ 0.3102,  0.0826,  0.0826,  ...,  1.4657,  1.7108,  0.7829],
         [ 0.0301, -0.1450, -0.0224,  ...,  0.6779,  1.0455,  0.3277],
         ...,
         [-1.1078, -0.5826,  0.3452,  ...,  1.1331,  0.0826, -0.8452],
         [-1.2129, -0.8978, -0.3025,  ...,  1.0280, -0.5476, -1.0903],
         [-0.9853, -1.2479, -0.4951,  ...,  0.8704, -0.8452, -0.9503]],

        [[ 1.8383,  1.6988,  1.5594,  ...,  1.0365,  1.6988,  1.0714],
         [ 1.8208,  1.7163,  1.6291,  ...,  

In [74]:
print(val_data[0])

(tensor([[[ 0.2111,  0.6049,  1.5125,  ...,  0.9303,  1.0673,  0.4337],
         [ 0.2111,  0.1597,  0.5193,  ...,  1.1872,  1.2899, -0.1828],
         [ 0.2453,  0.1426,  0.2796,  ...,  1.6324,  1.2557,  0.2282],
         ...,
         [ 1.0673,  0.3138,  0.6906,  ...,  0.8276,  0.7762,  0.5707],
         [ 1.7865,  0.4679,  0.5878,  ...,  0.4851,  0.5536,  0.6563],
         [ 1.2214,  0.8447,  1.1700,  ...,  0.6221,  0.6392,  0.6049]],

        [[-0.8627, -0.1800,  1.1155,  ...,  0.3452,  0.5028, -0.1800],
         [-0.7402, -0.8277, -0.2500,  ...,  0.7304,  0.6954, -0.9678],
         [-0.5826, -0.8803, -0.5826,  ...,  1.3256,  0.7479, -0.4951],
         ...,
         [ 0.5903, -0.3375, -0.0924,  ...,  0.0126, -0.0049, -0.2325],
         [ 1.5182, -0.0574, -0.1275,  ..., -0.3550, -0.3375, -0.2325],
         [ 0.8704,  0.3452,  0.5903,  ..., -0.2325, -0.1800, -0.2325]],

        [[ 0.1302,  0.5485,  1.6465,  ...,  1.0191,  1.4374,  0.7751],
         [ 0.2522,  0.0779,  0.4265,  ...,  

In [75]:
transformed_val_data = []
for idx in range(len(val_data)):
  transformed_val_data.append((val_data[idx][0], torch.tensor(2) if val_data[idx][1].item() == 0 else torch.tensor(3), val_data[idx][2]))

In [76]:
print(transformed_val_data[0])

(tensor([[[ 0.2111,  0.6049,  1.5125,  ...,  0.9303,  1.0673,  0.4337],
         [ 0.2111,  0.1597,  0.5193,  ...,  1.1872,  1.2899, -0.1828],
         [ 0.2453,  0.1426,  0.2796,  ...,  1.6324,  1.2557,  0.2282],
         ...,
         [ 1.0673,  0.3138,  0.6906,  ...,  0.8276,  0.7762,  0.5707],
         [ 1.7865,  0.4679,  0.5878,  ...,  0.4851,  0.5536,  0.6563],
         [ 1.2214,  0.8447,  1.1700,  ...,  0.6221,  0.6392,  0.6049]],

        [[-0.8627, -0.1800,  1.1155,  ...,  0.3452,  0.5028, -0.1800],
         [-0.7402, -0.8277, -0.2500,  ...,  0.7304,  0.6954, -0.9678],
         [-0.5826, -0.8803, -0.5826,  ...,  1.3256,  0.7479, -0.4951],
         ...,
         [ 0.5903, -0.3375, -0.0924,  ...,  0.0126, -0.0049, -0.2325],
         [ 1.5182, -0.0574, -0.1275,  ..., -0.3550, -0.3375, -0.2325],
         [ 0.8704,  0.3452,  0.5903,  ..., -0.2325, -0.1800, -0.2325]],

        [[ 0.1302,  0.5485,  1.6465,  ...,  1.0191,  1.4374,  0.7751],
         [ 0.2522,  0.0779,  0.4265,  ...,  

In [77]:
# for idx in range(len(train_data)):
  # transformed_val_data.append(train_data[idx])

In [78]:
# print(len(transformed_val_data))

In [79]:
resnet18_pretrained = models.resnet18(pretrained=True)

In [80]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_epochs = 5

In [81]:
resnet18_pretrained.fc = nn.Linear(in_features=512, out_features=2, bias=True)
resnet18_pretrained.load_state_dict(torch.load("resnet18_pretrained_all_grad.pt"))
resnet18_pretrained.to(device)
def remove_classification_head(model):
    modules = list(model.children())[:-1]
    model = nn.Sequential(*modules)
    return model

resnet18_pretrained = remove_classification_head(resnet18_pretrained)

In [82]:
class SiameseNetworkDataset(Dataset):
    def __init__(self,iid_dataset, ood_dataset=None):
        self.data = iid_dataset
        self.ood_data = ood_dataset
        
    def __getitem__(self,index):
        img0_tuple = random.choice(self.data)

        #We need to approximately 50% of images to be in the same class
        should_get_same_class = random.randint(0,1) 
        if should_get_same_class == 1:
            while True:
                #Look untill the same class image is found
                img1_tuple = random.choice(self.data) 
                if img0_tuple[1] == img1_tuple[1]:
                    break
        else:
            should_get_ood = 1
            if should_get_ood == 1 and self.ood_data is not None:
                while True:
                    #Look untill a different class image is found
                    img1_tuple = random.choice(self.ood_data)
                    if img0_tuple[1] != img1_tuple[1]:
                        break
            else:
                while True:
                    #Look untill a different class image is found
                    img1_tuple = random.choice(self.data)
                    if img0_tuple[1] != img1_tuple[1]:
                        break

        img0 = img0_tuple[0]
        img1 = img1_tuple[0]
        
        return img0, img1, torch.from_numpy(np.array([int(img1_tuple[1] != img0_tuple[1])], dtype=np.float32))
    
    def __len__(self):
        return len(self.data)

In [83]:
import random

In [84]:
train_dataset = SiameseNetworkDataset(train_data, transformed_val_data)
id_val_dataset = SiameseNetworkDataset(id_val_data)
val_dataset = SiameseNetworkDataset(val_data)

In [85]:
test_dataset = SiameseNetworkDataset(test_data)

In [86]:
#create the Siamese Neural Network
class SiameseNetwork(nn.Module):

    def __init__(self):
        super(SiameseNetwork, self).__init__()

        # Setting up the Sequential of CNN Layers
        self.embed = resnet18_pretrained

        # Setting up the Fully Connected Layers
        self.fc1 = nn.Sequential(            
            nn.Linear(512,2)
        )
        
    def forward_once(self, x):
        # This function will be called for both images
        # Its output is used to determine the similiarity
        output = self.embed(x)
        output = output.view(output.size()[0], -1)
        output = self.fc1(output)
        return output

    def forward(self, input1, input2):
        # In this function we pass in both images and obtain both vectors
        # which are returned
        output1 = self.forward_once(input1)
        output2 = self.forward_once(input2)

        return output1, output2

In [87]:
# Define the Contrastive Loss Function
class ContrastiveLoss(torch.nn.Module):
    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
      # Calculate the euclidean distance and calculate the contrastive loss
      euclidean_distance = F.pairwise_distance(output1, output2, keepdim = True)

      loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) +
                                    (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))


      return loss_contrastive

In [88]:
train_loader = DataLoader(train_dataset,
                        shuffle=True,
                        batch_size=BATCH_SIZE)

id_val_loader = DataLoader(id_val_dataset,
                        batch_size=BATCH_SIZE)

net = SiameseNetwork().to(device)

loss_criterion = ContrastiveLoss()

optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

In [89]:
def train_epoch(model, train_dataloader, loss_crt, optimizer, device):
    """
    model: Model object
    train_dataloader: DataLoader over the training dataset
    loss_crt: loss function object
    optimizer: Optimizer object
    device: torch.device('cpu) or torch.device('cuda')

    The function returns:
     - the epoch training loss, which is an average over the individual batch
       losses
    """
    model.train()
    epoch_loss = 0.0

    num_batches = len(train_dataloader)
    # Iterate over batches
    for i, (img0, img1, label) in tqdm(enumerate(train_dataloader, 0)):

        # Send the images and labels to CUDA
        img0, img1, label = img0.to(device), img1.to(device), label.to(device)

        # Zero the gradients
        model.zero_grad()

        # Pass in the two images into the network and obtain two outputs
        output1, output2 = model(img0, img1)

        # Pass the outputs of the networks and label into the loss function
        loss_contrastive = loss_crt(output1, output2, label)

        epoch_loss += loss_contrastive.item()

        # Calculate the backpropagation
        loss_contrastive.backward()

        # Optimize
        optimizer.step()


    epoch_loss = epoch_loss/num_batches
    return epoch_loss

def eval_epoch(model, val_dataloader, loss_crt, device):
    """
    model: Model object
    val_dataloader: DataLoader over the validation dataset
    loss_crt: loss function object
    device: torch.device('cpu) or torch.device('cuda')

    The function returns:
     - the epoch validation loss, which is an average over the individual batch
       losses
    """
    model.eval()
    epoch_loss = 0.0

    num_batches = len(val_dataloader)
    with torch.no_grad():
      # Iterate over batches
      for i, (img0, img1, label) in tqdm(enumerate(val_dataloader, 0)):

          # Send the images and labels to CUDA
          img0, img1, label = img0.to(device), img1.to(device), label.to(device)

          # Pass in the two images into the network and obtain two outputs
          output1, output2 = model(img0, img1)

          # Pass the outputs of the networks and label into the loss function
          loss_contrastive = loss_crt(output1, output2, label)

          epoch_loss += loss_contrastive.item()

    epoch_loss = epoch_loss/num_batches
  
    return epoch_loss

In [90]:
train_losses = []
train_accuracies = []
id_val_losses = []
id_val_accuracies = []
for epoch in range(1, num_epochs+1):
  train_loss = train_epoch(net, train_loader, loss_criterion, optimizer, device)
  val_loss = eval_epoch(net, id_val_loader, loss_criterion, device)
  train_losses.append(train_loss)
  id_val_losses.append(val_loss)
  print('\nEpoch %d'%(epoch))
  print('train loss: %10.8f'%(train_loss))
  print('id_val loss: %10.8f'%(val_loss))

3119it [11:07,  4.67it/s]
347it [01:25,  4.07it/s]



Epoch 1
train loss: 0.18511250
id_val loss: 1.80241119


3119it [09:46,  5.31it/s]
347it [01:11,  4.84it/s]



Epoch 2
train loss: 0.09901252
id_val loss: 1.62022801


3119it [09:21,  5.55it/s]
347it [01:11,  4.89it/s]



Epoch 3
train loss: 0.04757090
id_val loss: 1.93424508


3119it [09:10,  5.67it/s]
347it [01:09,  5.00it/s]



Epoch 4
train loss: 0.04809250
id_val loss: 1.76319959


3119it [09:10,  5.67it/s]
347it [01:11,  4.89it/s]


Epoch 5
train loss: 0.02786437
id_val loss: 1.81552285





In [91]:
val_loader = DataLoader(val_dataset,
                        batch_size=BATCH_SIZE)

In [92]:
import torchvision
from matplotlib.pyplot import imshow
import torchvision.utils
# Grab one image that we are going to test
dataiter = iter(train_loader)
x0, _, label1 = next(dataiter)
dataiter = iter(id_val_loader)

dissimilarity = [[], []]

for i in range(5):
    # Iterate over 5 images and test them with the first image (x0)
    _, x1, label2 = next(dataiter)
    
    output1, output2 = net(x0.cuda(), x1.cuda())

    euclidean_distance = F.pairwise_distance(output1, output2)
    """
    talking on the phone
    """
    for idx in range(BATCH_SIZE):
      # print(label1[idx].item(), label2[idx].item())
      # print(f'Dissimilarity: {euclidean_distance[idx].item():.2f}')
      # print("\n")
      dissimilarity[1 if label1[idx].item() == label2[idx].item() else 0].append(euclidean_distance[idx].item())




In [93]:
import numpy as np
def calc_mean_std(arr):
    arr = np.asarray(arr)
    return arr.mean(), arr.std()


print(calc_mean_std(dissimilarity[0]))
print(calc_mean_std(dissimilarity[1]))

(0.10195220599847811, 0.08368997196062244)
(0.10382272435287368, 0.07654692009587087)


In [94]:
dissimilarity = [[], []]
dataiter = iter(val_loader)
for i in range(5):
    # Iterate over 5 images and test them with the first image (x0)
    _, x1, label2 = next(dataiter)
    
    output1, output2 = net(x0.cuda(), x1.cuda())

    euclidean_distance = F.pairwise_distance(output1, output2)
  
    for idx in range(BATCH_SIZE):
      # print(label1[idx].item(), label2[idx].item())
      # print(f'Dissimilarity: {euclidean_distance[idx].item():.2f}')
      dissimilarity[1 if label1[idx].item() == label2[idx].item() else 0].append(euclidean_distance[idx].item())

In [95]:
print(calc_mean_std(dissimilarity[0]))
print(calc_mean_std(dissimilarity[1]))

(1.606208580991496, 0.5829712782028255)
(1.5370301122174543, 0.6248383777325054)


In [96]:
test_loader = DataLoader(test_dataset,
                        batch_size=BATCH_SIZE)

In [97]:
dissimilarity = [[], []]
dataiter = iter(test_loader)
for i in range(5):
    # Iterate over 5 images and test them with the first image (x0)
    _, x1, label2 = next(dataiter)
    
    output1, output2 = net(x0.cuda(), x1.cuda())

    euclidean_distance = F.pairwise_distance(output1, output2)
    
    for idx in range(BATCH_SIZE):
      # print(label1[idx].item(), label2[idx].item())
      # print(f'Dissimilarity: {euclidean_distance[idx].item():.2f}')
      dissimilarity[1 if label1[idx].item() == label2[idx].item() else 0].append(euclidean_distance[idx].item())

In [98]:
"""

TODO: SAVE THE MODEL

"""

'\n\nTODO: SAVE THE MODEL\n\n'

In [99]:
print(calc_mean_std(dissimilarity[0]))
print(calc_mean_std(dissimilarity[1]))

(0.20355521790851425, 0.11737141230295087)
(0.1575345179563473, 0.11288933166877935)


In [100]:
torch.save(net.state_dict(), "contrastive_resnet18_adam.pt")

In [101]:
train_loader = get_train_loader("standard", train_data, batch_size=BATCH_SIZE)
id_val_loader = get_train_loader("standard", id_val_data, batch_size=BATCH_SIZE)
test_loader = get_train_loader("standard", test_data, batch_size=BATCH_SIZE)



In [102]:
x1, label1, _ = next(iter(id_val_loader))

net.eval()
dissimilarity = [[], []]
with torch.no_grad():
    for batch_idx, batch in tqdm(enumerate(train_loader)):
        # Iterate over 5 images and test them with the first image (x0)
        x1, label2, _ = batch
        
        output1, output2 = net(x0.cuda(), x1.cuda())
        try:
            euclidean_distance = F.pairwise_distance(output1, output2)
            
            for idx in range(BATCH_SIZE):
            # print(label1[idx].item(), label2[idx].item())
            # print(f'Dissimilarity: {euclidean_distance[idx].item():.2f}')
                dissimilarity[1 if label1[idx].item() == label2[idx].item() else 0].append(euclidean_distance[idx].item())
        except:
            pass




3119it [05:12,  9.97it/s]


In [103]:
print(calc_mean_std(dissimilarity[0]))
print(calc_mean_std(dissimilarity[1]))

(0.09867890974763098, 0.07807393506476464)
(0.09779296635776634, 0.07629760807766515)


In [104]:
x1, label1, _ = next(iter(test_loader))

net.eval()
dissimilarity = [[], []]
with torch.no_grad():
    for batch_idx, batch in tqdm(enumerate(train_loader)):
        # Iterate over 5 images and test them with the first image (x0)
        x1, label2, _ = batch
        
        output1, output2 = net(x0.cuda(), x1.cuda())
        try:
            euclidean_distance = F.pairwise_distance(output1, output2)
            
            for idx in range(BATCH_SIZE):
            # print(label1[idx].item(), label2[idx].item())
            # print(f'Dissimilarity: {euclidean_distance[idx].item():.2f}')
                dissimilarity[1 if label1[idx].item() == label2[idx].item() else 0].append(euclidean_distance[idx].item())
        except:
            pass

3119it [05:08, 10.11it/s]


In [105]:
print(calc_mean_std(dissimilarity[0]))
print(calc_mean_std(dissimilarity[1]))

(0.09837486262146662, 0.0783391900471739)
(0.0983063083832786, 0.0756569049434798)
