In [1]:
# !pip install wilds

In [2]:
import numpy as np
from wilds import get_dataset
from wilds.common.data_loaders import get_train_loader
import torchvision.transforms as transforms
import torchvision.models as models
from wilds.common.data_loaders import get_eval_loader
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
import torch.nn.functional as F
from tqdm import tqdm
import random

In [3]:
# Load the full dataset, and download it if necessary
dataset = get_dataset(dataset="camelyon17", download=True)

Downloading dataset to data/camelyon17_v1.0...
You can also download the dataset manually at https://wilds.stanford.edu/downloads.
Downloading https://worksheets.codalab.org/rest/bundles/0xe45e15f39fb54e9d9e919556af67aabe/contents/blob/ to data/camelyon17_v1.0/archive.tar.gz


  0%|          | 0/10658709504 [00:00<?, ?Byte/s]

Extracting data/camelyon17_v1.0/archive.tar.gz to data/camelyon17_v1.0

It took 59.83 minutes to download and uncompress the dataset.



In [4]:
BATCH_SIZE = 32
FRACTION = 0.33

# Get the training set
train_data = dataset.get_subset(
    "train",
    frac = FRACTION,
    transform=transforms.Compose(
        [
         transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])]
    ),
)

print(len(train_data)) #302436 initially


99804


In [5]:
# Get the id_val set
id_val_data = dataset.get_subset(
    "id_val",
    frac = FRACTION,
    transform=transforms.Compose(
        [
         transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])]
    ),
)

print(len(id_val_data))

11075


In [6]:
# Get the val set
val_data = dataset.get_subset(
    "val",
    frac = FRACTION,
    transform=transforms.Compose(
        [
         transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])]
    ),
)

print(len(val_data))

11518


In [7]:
# Get the test set
test_data = dataset.get_subset(
    "test",
    frac = FRACTION,
    transform=transforms.Compose(
        [
         transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])]
    ),
)

print(len(test_data))


28068


In [8]:
resnet18_pretrained = models.resnet18(pretrained=True)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

In [9]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_epochs = 5

In [10]:
resnet18_pretrained.fc = nn.Linear(in_features=512, out_features=2, bias=True)
# resnet18_pretrained.load_state_dict(torch.load("resnet18_pretrained_all_grad.pt"))
resnet18_pretrained.to(device)
def remove_classification_head(model):
    modules = list(model.children())[:-1]
    model = nn.Sequential(*modules)
    return model

resnet18_pretrained = remove_classification_head(resnet18_pretrained)

In [11]:
class SiameseNetworkDataset(Dataset):
    def __init__(self,iid_dataset, ood_dataset=None):
        self.data = iid_dataset
        self.ood_data = ood_dataset
        
    def __getitem__(self,index):
        img0_tuple = random.choice(self.data)
        
        should_get_same_distribution = random.randint(0,1) # 50% chance
        if should_get_same_distribution == 1:
          img1_tuple = random.choice(self.data) 
                
        else:
          img1_tuple = random.choice(self.ood_data)

        img0 = img0_tuple[0]
        img1 = img1_tuple[0]
        
        return img0, img1, torch.from_numpy(np.array([should_get_same_distribution], dtype=np.float32))
    
    def __len__(self):
        return len(self.data)

In [12]:
train_dataset = SiameseNetworkDataset(train_data, test_data)
id_val_dataset = SiameseNetworkDataset(id_val_data, val_data)

In [13]:
#create the Siamese Neural Network
class SiameseNetwork(nn.Module):

    def __init__(self):
        super(SiameseNetwork, self).__init__()

        # Setting up the Sequential of CNN Layers
        self.embed = resnet18_pretrained

        # Setting up the Fully Connected Layers
        self.fc1 = nn.Sequential(            
            nn.Linear(512,2)
        )
        
    def forward_once(self, x):
        # This function will be called for both images
        # Its output is used to determine the similiarity
        output = self.embed(x)
        output = output.view(output.size()[0], -1)
        output = self.fc1(output)
        return output

    def forward(self, input1, input2):
        # In this function we pass in both images and obtain both vectors
        # which are returned
        output1 = self.forward_once(input1)
        output2 = self.forward_once(input2)

        return output1, output2

In [14]:
# Define the Contrastive Loss Function
class ContrastiveLoss(torch.nn.Module):
    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
      # Calculate the euclidean distance and calculate the contrastive loss
      euclidean_distance = F.pairwise_distance(output1, output2, keepdim = True)

      loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) +
                                    (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))


      return loss_contrastive

In [15]:
train_loader = DataLoader(train_dataset,
                        shuffle=True,
                        batch_size=BATCH_SIZE)

id_val_loader = DataLoader(id_val_dataset,
                        batch_size=BATCH_SIZE)

net = SiameseNetwork().to(device)

loss_criterion = ContrastiveLoss()

optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

In [16]:
def train_epoch(model, train_dataloader, loss_crt, optimizer, device):
    """
    model: Model object
    train_dataloader: DataLoader over the training dataset
    loss_crt: loss function object
    optimizer: Optimizer object
    device: torch.device('cpu) or torch.device('cuda')

    The function returns:
     - the epoch training loss, which is an average over the individual batch
       losses
    """
    model.train()
    epoch_loss = 0.0

    num_batches = len(train_dataloader)
    # Iterate over batches
    for i, (img0, img1, label) in tqdm(enumerate(train_dataloader, 0)):

        # Send the images and labels to CUDA
        img0, img1, label = img0.to(device), img1.to(device), label.to(device)

        # Zero the gradients
        model.zero_grad()

        # Pass in the two images into the network and obtain two outputs
        output1, output2 = model(img0, img1)

        # Pass the outputs of the networks and label into the loss function
        loss_contrastive = loss_crt(output1, output2, label)

        epoch_loss += loss_contrastive.item()

        # Calculate the backpropagation
        loss_contrastive.backward()

        # Optimize
        optimizer.step()


    epoch_loss = epoch_loss/num_batches
    return epoch_loss

def eval_epoch(model, val_dataloader, loss_crt, device):
    """
    model: Model object
    val_dataloader: DataLoader over the validation dataset
    loss_crt: loss function object
    device: torch.device('cpu) or torch.device('cuda')

    The function returns:
     - the epoch validation loss, which is an average over the individual batch
       losses
    """
    model.eval()
    epoch_loss = 0.0

    num_batches = len(val_dataloader)
    with torch.no_grad():
      # Iterate over batches
      for i, (img0, img1, label) in tqdm(enumerate(val_dataloader, 0)):

          # Send the images and labels to CUDA
          img0, img1, label = img0.to(device), img1.to(device), label.to(device)

          # Pass in the two images into the network and obtain two outputs
          output1, output2 = model(img0, img1)

          # Pass the outputs of the networks and label into the loss function
          loss_contrastive = loss_crt(output1, output2, label)

          epoch_loss += loss_contrastive.item()

    epoch_loss = epoch_loss/num_batches
  
    return epoch_loss

In [17]:
train_losses = []
train_accuracies = []
id_val_losses = []
id_val_accuracies = []
for epoch in range(1, num_epochs+1):
  train_loss = train_epoch(net, train_loader, loss_criterion, optimizer, device)
  val_loss = eval_epoch(net, id_val_loader, loss_criterion, device)
  train_losses.append(train_loss)
  id_val_losses.append(val_loss)
  print('\nEpoch %d'%(epoch))
  print('train loss: %10.8f'%(train_loss))
  print('id_val loss: %10.8f'%(val_loss))

3119it [11:21,  4.58it/s]
347it [01:16,  4.53it/s]



Epoch 1
train loss: 0.13897674
id_val loss: 1.61361878


3119it [09:54,  5.24it/s]
347it [01:03,  5.49it/s]



Epoch 2
train loss: 0.08642494
id_val loss: 2.23099485


3119it [09:18,  5.59it/s]
347it [01:00,  5.75it/s]



Epoch 3
train loss: 0.01951828
id_val loss: 1.87700076


3119it [09:11,  5.66it/s]
347it [01:01,  5.64it/s]



Epoch 4
train loss: 0.02655457
id_val loss: 1.69383542


3119it [09:04,  5.73it/s]
347it [00:58,  5.91it/s]


Epoch 5
train loss: 0.01269672
id_val loss: 1.66902633





In [18]:
import torchvision
from matplotlib.pyplot import imshow
import torchvision.utils
# Grab one image that we are going to test
dataiter = iter(train_loader)
x0, _, label1 = next(dataiter)
dataiter = iter(id_val_loader)

dissimilarity = [[], []]

for i in range(5):
    # Iterate over 5 images and test them with the first image (x0)
    _, x1, label2 = next(dataiter)
    
    output1, output2 = net(x0.cuda(), x1.cuda())

    euclidean_distance = F.pairwise_distance(output1, output2)
    """
    talking on the phone
    """
    for idx in range(BATCH_SIZE):
      # print(label1[idx].item(), label2[idx].item())
      # print(f'Dissimilarity: {euclidean_distance[idx].item():.2f}')
      # print("\n")
      dissimilarity[1 if label1[idx].item() == label2[idx].item() else 0].append(euclidean_distance[idx].item())




In [19]:
import numpy as np
def calc_mean_std(arr):
    arr = np.asarray(arr)
    return arr.mean(), arr.std()


print(calc_mean_std(dissimilarity[0]))
print(calc_mean_std(dissimilarity[1]))

(0.1970403641180107, 0.23923311584199464)
(0.18752764111156872, 0.29753863465416225)


In [20]:
torch.save(net.state_dict(), "contrastive_resnet18_adam.pt")

In [21]:
train_loader = get_train_loader("standard", train_data, batch_size=BATCH_SIZE)
id_val_loader = get_train_loader("standard", id_val_data, batch_size=BATCH_SIZE)
val_loader = get_train_loader("standard", val_data, batch_size=BATCH_SIZE)



In [22]:
x1, label1, _ = next(iter(id_val_loader))

net.eval()
dissimilarity = [[], []]
with torch.no_grad():
    for batch_idx, batch in tqdm(enumerate(train_loader)):
        # Iterate over 5 images and test them with the first image (x0)
        x1, label2, _ = batch
        
        output1, output2 = net(x0.cuda(), x1.cuda())
        try:
            euclidean_distance = F.pairwise_distance(output1, output2)
            
            for idx in range(BATCH_SIZE):
            # print(label1[idx].item(), label2[idx].item())
            # print(f'Dissimilarity: {euclidean_distance[idx].item():.2f}')
                dissimilarity[1 if label1[idx].item() == label2[idx].item() else 0].append(euclidean_distance[idx].item())
        except:
            pass




3119it [05:14,  9.91it/s]


In [23]:
print(calc_mean_std(dissimilarity[0]))
print(calc_mean_std(dissimilarity[1]))

(0.23794319680332573, 0.3468491630828472)
(0.22010003775783324, 0.30061117292724826)


In [24]:
x1, label1, _ = next(iter(val_loader))

net.eval()
dissimilarity = [[], []]
with torch.no_grad():
    for batch_idx, batch in tqdm(enumerate(train_loader)):
        # Iterate over 5 images and test them with the first image (x0)
        x1, label2, _ = batch
        
        output1, output2 = net(x0.cuda(), x1.cuda())
        try:
            euclidean_distance = F.pairwise_distance(output1, output2)
            
            for idx in range(BATCH_SIZE):
            # print(label1[idx].item(), label2[idx].item())
            # print(f'Dissimilarity: {euclidean_distance[idx].item():.2f}')
                dissimilarity[1 if label1[idx].item() == label2[idx].item() else 0].append(euclidean_distance[idx].item())
        except:
            pass

3119it [05:12,  9.97it/s]


In [25]:
print(calc_mean_std(dissimilarity[0]))
print(calc_mean_std(dissimilarity[1]))

(0.22293844498786872, 0.30882887590911995)
(0.23403376401490156, 0.33716413543139645)


In [27]:
print(dissimilarity[1])

[0.2836695611476898, 0.2093307077884674, 0.07745309919118881, 0.14550858736038208, 0.3562085032463074, 0.1933477520942688, 0.8160108327865601, 0.016783298924565315, 0.0752619132399559, 0.047687042504549026, 0.24598334729671478, 0.08313785493373871, 0.12203196436166763, 0.09993915259838104, 0.3189719319343567, 0.030825812369585037, 0.06017804145812988, 0.1326804906129837, 0.8435861468315125, 0.19445259869098663, 0.2557825446128845, 0.13409213721752167, 0.09393972158432007, 0.10094854980707169, 0.10426456481218338, 0.08489972352981567, 0.013394956476986408, 0.05339782312512398, 0.08009808510541916, 0.038799285888671875, 0.977784276008606, 0.0790894478559494, 0.08795386552810669, 0.10678756982088089, 0.38802632689476013, 0.17290011048316956, 0.20285651087760925, 0.13855469226837158, 0.3330739438533783, 0.0549432709813118, 0.6174249649047852, 0.25995174050331116, 0.14573246240615845, 0.13727976381778717, 0.04396975040435791, 0.5167291760444641, 0.03924184292554855, 0.10156222432851791, 0.0