In [1]:
# !pip install wilds

In [2]:
import numpy as np
from wilds import get_dataset
from wilds.common.data_loaders import get_train_loader
import torchvision.transforms as transforms
import torchvision.models as models
from wilds.common.data_loaders import get_eval_loader
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
import torch.nn.functional as F
from tqdm import tqdm

In [3]:
# Load the full dataset, and download it if necessary
dataset = get_dataset(dataset="camelyon17", download=False)

In [4]:
BATCH_SIZE = 32
FRACTION = 0.33

# Get the training set
train_data = dataset.get_subset(
    "train",
    frac = FRACTION,
    transform=transforms.Compose(
        [
         transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])]
    ),
)

print(len(train_data)) #302436 initially

"""
# (Optional) Load unlabeled data
dataset = get_dataset(dataset="camelyon17", download=True, unlabeled=True)
unlabeled_data = dataset.get_subset(
    "test_unlabeled",
    transform=transforms.Compose(
        [transforms.Resize((448, 448)), transforms.ToTensor()]
    ),
)
unlabeled_loader = get_train_loader("standard", unlabeled_data, batch_size=16)
"""
"""
# Train loop
for labeled_batch, unlabeled_batch in zip(train_loader, unlabeled_loader):
    x, y, metadata = labeled_batch
    unlabeled_x, unlabeled_metadata = unlabeled_batch
    ...
"""

99804


'\n# Train loop\nfor labeled_batch, unlabeled_batch in zip(train_loader, unlabeled_loader):\n    x, y, metadata = labeled_batch\n    unlabeled_x, unlabeled_metadata = unlabeled_batch\n    ...\n'

In [5]:
# Get the id_val set
id_val_data = dataset.get_subset(
    "id_val",
    frac = FRACTION,
    transform=transforms.Compose(
        [
         transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])]
    ),
)

print(len(id_val_data))

11075


In [6]:
# Get the val set
val_data = dataset.get_subset(
    "val",
    frac = FRACTION,
    transform=transforms.Compose(
        [
         transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])]
    ),
)

print(len(val_data))

11518


In [7]:
# Get the test set
test_data = dataset.get_subset(
    "test",
    frac = FRACTION,
    transform=transforms.Compose(
        [
         transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])]
    ),
)

print(len(test_data))


28068


In [8]:
print(test_data[0])

(tensor([[[ 0.2282,  0.0569,  0.6221,  ...,  0.1083,  0.1254,  0.0741],
         [ 0.1254,  0.1426,  0.9474,  ...,  0.2796,  0.3309,  0.1768],
         [-0.2171,  0.0398,  1.3413,  ...,  0.1426,  0.2453,  0.0227],
         ...,
         [-0.1657,  0.9817,  1.5982,  ...,  1.3070,  1.4612,  1.5297],
         [-0.2856,  0.6221,  1.5982,  ...,  1.2899,  1.2385,  1.7865],
         [ 0.5364,  0.4851,  1.0673,  ...,  1.0844,  1.1700,  1.1872]],

        [[-0.6527, -0.4776, -0.5476,  ..., -0.4601, -0.5826, -0.4251],
         [-0.6877, -0.6702, -0.2325,  ..., -0.3375, -0.4601, -0.4251],
         [-0.8452, -0.7052,  0.0476,  ..., -0.4251, -0.4951, -0.5476],
         ...,
         [-1.0728,  0.2402,  1.0805,  ...,  0.1877,  0.3978,  0.9405],
         [-0.7927, -0.3025,  0.7654,  ...,  0.1352,  0.2577,  0.2752],
         [-0.3375, -0.3200,  0.6954,  ..., -0.0399,  0.0476,  0.0301]],

        [[ 1.1411,  1.2457,  1.2805,  ...,  1.2457,  1.3328,  1.3154],
         [ 1.2108,  1.2108,  1.3154,  ...,  

In [9]:
print(val_data[0])

(tensor([[[ 0.6049,  0.7591,  0.8276,  ...,  0.3823,  0.3652,  0.5536],
         [ 0.6392,  0.5536,  0.5022,  ...,  0.4166,  0.3823,  0.4679],
         [ 0.4679,  0.2624,  0.1426,  ...,  0.2967,  0.3138,  0.5364],
         ...,
         [ 0.1939,  0.2453,  0.0741,  ..., -0.1314, -0.1143,  0.1083],
         [ 0.0398,  0.0741, -0.0629,  ..., -0.2171, -0.0458,  0.0741],
         [-0.1657, -0.1657, -0.2684,  ...,  0.0569,  0.3309,  0.3994]],

        [[-0.2500, -0.0049,  0.0651,  ..., -0.7927, -0.7052, -0.3200],
         [-0.3025, -0.2850, -0.2850,  ..., -0.7577, -0.7577, -0.4776],
         [-0.4601, -0.6176, -0.7577,  ..., -0.7227, -0.7227, -0.4251],
         ...,
         [-0.8102, -0.8277, -0.9153,  ..., -1.0028, -1.0378, -0.8277],
         [-0.8803, -0.9153, -1.0028,  ..., -1.0203, -0.8277, -0.6702],
         [-1.0728, -1.1429, -1.2129,  ..., -0.8452, -0.4776, -0.3200]],

        [[ 0.6705,  0.8099,  0.8797,  ...,  0.2173,  0.2696,  0.5659],
         [ 0.6008,  0.5659,  0.5485,  ...,  

In [10]:
transformed_val_data = []
for idx in range(len(val_data)):
  transformed_val_data.append((val_data[idx][0], torch.tensor(2) if val_data[idx][1].item() == 0 else torch.tensor(3), val_data[idx][2]))

In [11]:
print(transformed_val_data[0])

(tensor([[[ 0.6049,  0.7591,  0.8276,  ...,  0.3823,  0.3652,  0.5536],
         [ 0.6392,  0.5536,  0.5022,  ...,  0.4166,  0.3823,  0.4679],
         [ 0.4679,  0.2624,  0.1426,  ...,  0.2967,  0.3138,  0.5364],
         ...,
         [ 0.1939,  0.2453,  0.0741,  ..., -0.1314, -0.1143,  0.1083],
         [ 0.0398,  0.0741, -0.0629,  ..., -0.2171, -0.0458,  0.0741],
         [-0.1657, -0.1657, -0.2684,  ...,  0.0569,  0.3309,  0.3994]],

        [[-0.2500, -0.0049,  0.0651,  ..., -0.7927, -0.7052, -0.3200],
         [-0.3025, -0.2850, -0.2850,  ..., -0.7577, -0.7577, -0.4776],
         [-0.4601, -0.6176, -0.7577,  ..., -0.7227, -0.7227, -0.4251],
         ...,
         [-0.8102, -0.8277, -0.9153,  ..., -1.0028, -1.0378, -0.8277],
         [-0.8803, -0.9153, -1.0028,  ..., -1.0203, -0.8277, -0.6702],
         [-1.0728, -1.1429, -1.2129,  ..., -0.8452, -0.4776, -0.3200]],

        [[ 0.6705,  0.8099,  0.8797,  ...,  0.2173,  0.2696,  0.5659],
         [ 0.6008,  0.5659,  0.5485,  ...,  

In [12]:
# for idx in range(len(train_data)):
  # transformed_val_data.append(train_data[idx])

In [13]:
# print(len(transformed_val_data))

In [14]:
resnet18_pretrained = models.resnet18(pretrained=True)

In [15]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_epochs = 5

In [16]:
resnet18_pretrained.fc = nn.Linear(in_features=512, out_features=2, bias=True)
resnet18_pretrained.to(device)
# resnet18_pretrained.load_state_dict(torch.load("resnet18_pretrained_all_grad.pt"))
def remove_classification_head(model):
    modules = list(model.children())[:-1]
    model = nn.Sequential(*modules)
    return model

resnet18_pretrained = remove_classification_head(resnet18_pretrained)

In [17]:
class SiameseNetworkDataset(Dataset):
    def __init__(self,iid_dataset, ood_dataset=None):
        self.data = iid_dataset
        self.ood_data = ood_dataset
        
    def __getitem__(self,index):
        img0_tuple = random.choice(self.data)

        #We need to approximately 50% of images to be in the same class
        should_get_same_class = random.randint(0,1) 
        if should_get_same_class == 1:
            while True:
                #Look untill the same class image is found
                img1_tuple = random.choice(self.data) 
                if img0_tuple[1] == img1_tuple[1]:
                    break
        else:
            should_get_ood = random.randint(0, 1)
            if should_get_ood == 1 and self.ood_data is not None:
                while True:
                    #Look untill a different class image is found
                    img1_tuple = random.choice(self.ood_data)
                    if img0_tuple[1] != img1_tuple[1]:
                        break
            else:
                while True:
                    #Look untill a different class image is found
                    img1_tuple = random.choice(self.data)
                    if img0_tuple[1] != img1_tuple[1]:
                        break

        img0 = img0_tuple[0]
        img1 = img1_tuple[0]
        
        return img0, img1, torch.from_numpy(np.array([int(img1_tuple[1] != img0_tuple[1])], dtype=np.float32))
    
    def __len__(self):
        return len(self.data)

In [18]:
import random

In [19]:
train_dataset = SiameseNetworkDataset(train_data, transformed_val_data)
id_val_dataset = SiameseNetworkDataset(id_val_data)
val_dataset = SiameseNetworkDataset(val_data)

In [20]:
test_dataset = SiameseNetworkDataset(test_data)

In [21]:
#create the Siamese Neural Network
class SiameseNetwork(nn.Module):

    def __init__(self):
        super(SiameseNetwork, self).__init__()

        # Setting up the Sequential of CNN Layers
        self.embed = resnet18_pretrained

        # Setting up the Fully Connected Layers
        self.fc1 = nn.Sequential(            
            nn.Linear(512,2)
        )
        
    def forward_once(self, x):
        # This function will be called for both images
        # Its output is used to determine the similiarity
        output = self.embed(x)
        output = output.view(output.size()[0], -1)
        output = self.fc1(output)
        return output

    def forward(self, input1, input2):
        # In this function we pass in both images and obtain both vectors
        # which are returned
        output1 = self.forward_once(input1)
        output2 = self.forward_once(input2)

        return output1, output2

In [22]:
# Define the Contrastive Loss Function
class ContrastiveLoss(torch.nn.Module):
    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
      # Calculate the euclidean distance and calculate the contrastive loss
      euclidean_distance = F.pairwise_distance(output1, output2, keepdim = True)

      loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) +
                                    (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))


      return loss_contrastive

In [23]:
train_loader = DataLoader(train_dataset,
                        shuffle=True,
                        batch_size=BATCH_SIZE)

id_val_loader = DataLoader(id_val_dataset,
                        batch_size=BATCH_SIZE)

net = SiameseNetwork().to(device)

loss_criterion = ContrastiveLoss()

optimizer = torch.optim.SGD(resnet18_pretrained.parameters(), lr=0.01, momentum=0.9)

In [24]:
def train_epoch(model, train_dataloader, loss_crt, optimizer, device):
    """
    model: Model object
    train_dataloader: DataLoader over the training dataset
    loss_crt: loss function object
    optimizer: Optimizer object
    device: torch.device('cpu) or torch.device('cuda')

    The function returns:
     - the epoch training loss, which is an average over the individual batch
       losses
    """
    model.train()
    epoch_loss = 0.0

    num_batches = len(train_dataloader)
    # Iterate over batches
    for i, (img0, img1, label) in tqdm(enumerate(train_dataloader, 0)):

        # Send the images and labels to CUDA
        img0, img1, label = img0.to(device), img1.to(device), label.to(device)

        # Zero the gradients
        model.zero_grad()

        # Pass in the two images into the network and obtain two outputs
        output1, output2 = model(img0, img1)

        # Pass the outputs of the networks and label into the loss function
        loss_contrastive = loss_crt(output1, output2, label)

        epoch_loss += loss_contrastive.item()

        # Calculate the backpropagation
        loss_contrastive.backward()

        # Optimize
        optimizer.step()


    epoch_loss = epoch_loss/num_batches
    return epoch_loss

def eval_epoch(model, val_dataloader, loss_crt, device):
    """
    model: Model object
    val_dataloader: DataLoader over the validation dataset
    loss_crt: loss function object
    device: torch.device('cpu) or torch.device('cuda')

    The function returns:
     - the epoch validation loss, which is an average over the individual batch
       losses
    """
    model.eval()
    epoch_loss = 0.0

    num_batches = len(val_dataloader)
    with torch.no_grad():
      # Iterate over batches
      for i, (img0, img1, label) in tqdm(enumerate(val_dataloader, 0)):

          # Send the images and labels to CUDA
          img0, img1, label = img0.to(device), img1.to(device), label.to(device)

          # Pass in the two images into the network and obtain two outputs
          output1, output2 = model(img0, img1)

          # Pass the outputs of the networks and label into the loss function
          loss_contrastive = loss_crt(output1, output2, label)

          epoch_loss += loss_contrastive.item()

    epoch_loss = epoch_loss/num_batches
  
    return epoch_loss

In [25]:
train_losses = []
train_accuracies = []
id_val_losses = []
id_val_accuracies = []
for epoch in range(1, num_epochs+1):
  train_loss = train_epoch(net, train_loader, loss_criterion, optimizer, device)
  val_loss = eval_epoch(net, id_val_loader, loss_criterion, device)
  train_losses.append(train_loss)
  id_val_losses.append(val_loss)
  print('\nEpoch %d'%(epoch))
  print('train loss: %10.8f'%(train_loss))
  print('id_val loss: %10.8f'%(val_loss))

2396it [51:19,  1.29s/it]


KeyboardInterrupt: 

In [None]:
val_loader = DataLoader(val_dataset,
                        batch_size=BATCH_SIZE)

In [None]:
import torchvision
from matplotlib.pyplot import imshow
import torchvision.utils
# Grab one image that we are going to test
dataiter = iter(train_loader)
x0, _, label1 = next(dataiter)
dataiter = iter(id_val_loader)

dissimilarity = [[], []]

for i in range(5):
    # Iterate over 5 images and test them with the first image (x0)
    _, x1, label2 = next(dataiter)
    
    output1, output2 = net(x0.cuda(), x1.cuda())

    euclidean_distance = F.pairwise_distance(output1, output2)
    """
    talking on the phone
    """
    for idx in range(BATCH_SIZE):
      # print(label1[idx].item(), label2[idx].item())
      # print(f'Dissimilarity: {euclidean_distance[idx].item():.2f}')
      # print("\n")
      dissimilarity[1 if label1[idx].item() == label2[idx].item() else 0].append(euclidean_distance[idx].item())




In [None]:
import numpy as np
def calc_mean_std(arr):
    arr = np.asarray(arr)
    return arr.mean(), arr.std()


print(calc_mean_std(dissimilarity[0]))
print(calc_mean_std(dissimilarity[1]))

In [None]:
dissimilarity = [[], []]
dataiter = iter(val_loader)
for i in range(5):
    # Iterate over 5 images and test them with the first image (x0)
    _, x1, label2 = next(dataiter)
    
    output1, output2 = net(x0.cuda(), x1.cuda())

    euclidean_distance = F.pairwise_distance(output1, output2)
  
    for idx in range(BATCH_SIZE):
      # print(label1[idx].item(), label2[idx].item())
      # print(f'Dissimilarity: {euclidean_distance[idx].item():.2f}')
      dissimilarity[1 if label1[idx].item() == label2[idx].item() else 0].append(euclidean_distance[idx].item())

In [None]:
print(calc_mean_std(dissimilarity[0]))
print(calc_mean_std(dissimilarity[1]))

In [None]:
test_loader = DataLoader(test_dataset,
                        batch_size=BATCH_SIZE)

In [None]:
dissimilarity = [[], []]
dataiter = iter(test_loader)
for i in range(5):
    # Iterate over 5 images and test them with the first image (x0)
    _, x1, label2 = next(dataiter)
    
    output1, output2 = net(x0.cuda(), x1.cuda())

    euclidean_distance = F.pairwise_distance(output1, output2)
    
    for idx in range(BATCH_SIZE):
      # print(label1[idx].item(), label2[idx].item())
      # print(f'Dissimilarity: {euclidean_distance[idx].item():.2f}')
      dissimilarity[1 if label1[idx].item() == label2[idx].item() else 0].append(euclidean_distance[idx].item())

In [None]:
torch.save(net.state_dict(), "contrastive_resnet18_sgd.pt")

In [None]:
print(calc_mean_std(dissimilarity[0]))
print(calc_mean_std(dissimilarity[1]))