## Imports and initial data processing

In [None]:
from google.colab import files
import os, tqdm, numpy as np, copy
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

In [None]:
# Quick check: can PyTorch see a CUDA GPU? (It printed False on this runtime, so everything below ran on CPU.)
torch.cuda.is_available()

False

In [None]:
# Show the CUDA toolkit version and GPU/driver status; on this runtime nvidia-smi was not found (no GPU attached).
!nvcc --version
!nvidia-smi


nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Tue_Aug_15_22:02:13_PDT_2023
Cuda compilation tools, release 12.2, V12.2.140
Build cuda_12.2.r12.2/compiler.33191640_0
/bin/bash: line 1: nvidia-smi: command not found


In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU; later cells read `device` as a global.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
# Ask the CUDA caching allocator to use expandable segments (intended to reduce fragmentation);
# only takes effect if set before CUDA memory is first allocated.
%env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

cpu
env: PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True


In [None]:
uploaded = files.upload()
!unzip omniglot.zip >/dev/null

Saving omniglot.zip to omniglot.zip


Collect the filepath of every image, and assign each character directory an integer class label.

In [None]:
# Index every Omniglot image: each character directory gets an integer class label and each
# image file an integer "filepath index" (ints are cheaper to store than path strings).
IMAGE_ROOTS = ('images_background/images_background', 'images_evaluation/images_evaluation')


def index_omniglot_images(image_roots):
    """Walk ``image_roots`` in order and index every image file found.

    Every directory visited advances the label counter, including directories without
    files, so labels need not be contiguous. Directories that contain files become
    classes. Directories and files are visited in sorted order, so the labels (and the
    train/validation/test split derived from them) do not depend on the filesystem's
    listing order.

    Args:
        image_roots: directories to walk, in order. The label and filepath-index
            counters continue across them.

    Returns:
        (label_to_character, filepath_to_filepath_index, filepath_index_to_filepath,
        label_to_filepath_indices):
        - ``label_to_character`` maps a label to the concatenation of its directory's
          parent name and own name (alphabet + character).
        - ``label_to_filepath_indices`` maps a label to the list of filepath indices
          of its images.
    """
    label_to_character = {}
    filepath_to_filepath_index = {}
    filepath_index_to_filepath = {}
    label_to_filepath_indices = {}

    label = 0
    filepath_index = 0
    for image_root in image_roots:
        # Not named `files`: that would shadow google.colab.files used by the upload cell.
        for root, dirs, filenames in os.walk(image_root):
            dirs.sort()  # os.walk honours in-place edits of `dirs`, giving a deterministic traversal
            if filenames:
                label_to_character[label] = ''.join(root.split('/')[-2:])
                label_to_filepath_indices[label] = []
                for filename in sorted(filenames):
                    filepath = root + '/' + filename
                    filepath_to_filepath_index[filepath] = filepath_index
                    filepath_index_to_filepath[filepath_index] = filepath
                    label_to_filepath_indices[label].append(filepath_index)
                    filepath_index += 1
            label += 1

    return label_to_character, filepath_to_filepath_index, filepath_index_to_filepath, label_to_filepath_indices


(label_to_character,
 filepath_to_filepath_index,
 filepath_index_to_filepath,
 label_to_filepath_indices) = index_omniglot_images(IMAGE_ROOTS)

In [None]:
# Sanity check: an empty index usually means omniglot.zip was not extracted where expected.
assert label_to_filepath_indices, 'No Omniglot images were indexed; check that omniglot.zip was extracted.'
len(label_to_character)  # number of character classes found

Split the classes into training, validation, and test sets (one held-out character from each of the last ten alphabets).

In [None]:
# Pull the labels and character names out of the index (both in insertion order).
dict_items = label_to_character.items()
labels = [key for key, _ in dict_items]
characters = [name for _, name in dict_items]

# Character names are assumed to contain "character<NN>"; NN == '01' marks the first
# character of an alphabet, so each alphabet contributes exactly one flagged entry.
is_first = np.array([name.split('character')[1] == '01' for name in characters])
first_positions = np.flatnonzero(is_first)

# Hold out one character from each of the last 10 alphabets: 5 for validation, 5 for test.
valid_labels = [labels[pos] for pos in first_positions[-10:-5]]
test_labels = [labels[pos] for pos in first_positions[-5:]]
held_out = set(valid_labels) | set(test_labels)
train_labels = [lab for lab in labels if lab not in held_out]

In [None]:
# Converts PIL images to (C, H, W) float tensors (see torchvision ToTensor); used later by
# collate_fn and make_classifications to turn opened image files into model inputs.
converter = transforms.ToTensor()

Functions to sample random pairs or triplets of images. Each returns a tuple whose first element holds the filepath indices and whose second element holds the corresponding class labels.

In [None]:
def sample_pairs(num_same, num_diff, label_pool=None):
    """Sample matching and non-matching image pairs from the training classes.

    Args:
        num_same: number of pairs whose two (distinct) images share a class.
        num_diff: number of pairs whose images come from two distinct classes.
        label_pool: class labels to sample from. Defaults to the global ``train_labels``.
            Previously, non-matching pairs were drawn from every indexed label, which
            leaked validation and test classes into training.

    Returns:
        (pairs, labels), where ``pairs[k]`` holds two filepath indices and ``labels[k]``
        holds their two class labels. The ``num_same`` matching pairs come first.
    """
    if label_pool is None:
        label_pool = train_labels

    pairs = []   # [[index1, index2], [index3, index4], ...]
    labels = []  # [[label1, label2], [label3, label4], ...]

    for _ in range(num_same):
        same_label = np.random.choice(label_pool)
        pairs.append(np.random.choice(label_to_filepath_indices[same_label], 2, replace=False))
        labels.append(np.array([same_label, same_label]))

    for _ in range(num_diff):
        diff_labels = np.random.choice(label_pool, 2, replace=False)  # two distinct classes
        pairs.append([np.random.choice(label_to_filepath_indices[label]) for label in diff_labels])
        labels.append(diff_labels)

    return pairs, labels

In [None]:
def sample_triplets(num_triplets):
    """Sample (anchor, positive, negative) image triplets from the training classes.

    For each triplet, two distinct labels are drawn from ``train_labels``. The first
    label supplies two different images (anchor and positive); the second supplies
    one image (negative).

    Returns:
        (triplets, labels), where ``triplets[k]`` lists three filepath indices and
        ``labels[k]`` their class labels ``[same, same, different]``.
    """
    triplets, labels = [], []
    for _ in range(num_triplets):
        anchor_label, negative_label = np.random.choice(train_labels, 2, replace=False)
        positive_pair = np.random.choice(list(label_to_filepath_indices[anchor_label]), 2, replace=False)
        negative = np.random.choice(list(label_to_filepath_indices[negative_label]))
        triplets.append([*positive_pair, negative])
        labels.append([anchor_label, anchor_label, negative_label])
    return triplets, labels

## Set up the dataset and dataloader

In [None]:
class Omniglot_Dataset(Dataset):
    """Map-style dataset over pre-sampled pairs or triplets of image filepath indices.

    Items are returned as raw indices and labels; images are loaded later, in the
    DataLoader's collate function.

    Args:
        filepath_sets: sequence whose k-th entry holds the filepath indices of sample k.
        label_sets: sequence whose k-th entry holds the class labels of sample k.
        device: stored for reference; not used by this class.
        triplets: stored for reference; not used by this class.
    """

    def __init__(self, filepath_sets, label_sets, device, triplets=False):
        self.filepath_sets, self.label_sets = filepath_sets, label_sets
        self.device, self.triplets = device, triplets

    def __len__(self):
        return len(self.filepath_sets)

    def __getitem__(self, idx):
        # Samplers may hand over a tensor index; turn it into a plain Python index first.
        key = idx.tolist() if torch.is_tensor(idx) else idx
        return {"paths": self.filepath_sets[key], "labels": self.label_sets[key]}

def collate_fn(data):
    """Load and batch the images referenced by a list of dataset items.

    Args:
        data: list of dataset items, each ``{"paths": [...], "labels": [...]}`` with one
            filepath index and one label per position of the pair or triplet.

    Returns:
        dict with two entries, all tensors moved to the global ``device``:
        - ``"image_tensors"``: one tensor per position, with the items' images stacked
          along a new leading batch dimension.
        - ``"labels"``: one label tensor per position.

    Input images deliberately do not require grad. Nothing reads gradients with respect
    to the pixels, and enabling them made autograd do extra work every batch.
    """
    def load(filepath_index):
        # Convert while the file is open, then close the handle immediately.
        with Image.open(filepath_index_to_filepath[filepath_index]) as image:
            return converter(image)

    image_tensors = [
        torch.stack([load(item['paths'][pos]) for item in data]).to(device)
        for pos in range(len(data[0]['paths']))
    ]
    labels = [
        torch.tensor([item['labels'][pos] for item in data]).to(device)
        for pos in range(len(data[0]['labels']))
    ]
    return {"image_tensors": image_tensors, "labels": labels}

## Defining S-subtract and S-concat

In [None]:
class SiameseNet(nn.Module):
    """Siamese CNN ("S-Subtract") that scores whether two images show the same character.

    Both images go through one shared embedding branch. For every conv stage it applies
    Dropout2d -> Conv2d (stride 1, no padding) -> ReLU -> BatchNorm2d -> MaxPool2d; it then
    flattens, and for every linear layer applies Dropout -> Linear -> ReLU. The head maps
    |embed1 - embed2| through Dropout -> Linear -> sigmoid to one score per pair.

    Args:
        num_channels: output channels of each conv stage.
        kernel_sizes: conv kernel size of each stage.
        pool_sizes: max-pool window (and stride) of each stage.
        emb_dims: output width of each linear embedding layer.
        conv_dropout: Dropout2d probability applied before every conv stage.
        linear_dropout: dropout probability before every linear layer and the head.

    Inputs are expected to be single-channel 105x105 images.
    """

    def __init__(self, num_channels, kernel_sizes, pool_sizes, emb_dims, conv_dropout, linear_dropout):

        super().__init__()

        self.num_channels = num_channels
        self.kernel_sizes = kernel_sizes
        self.pool_sizes = pool_sizes
        self.emb_dims = emb_dims
        self.conv_dropout = conv_dropout
        self.linear_dropout = linear_dropout

        # Side length of the feature map that is flattened into the first linear layer.
        # A stride-1 unpadded conv shrinks the side by (k - 1), and MaxPool2d(p) (stride p,
        # floor mode) then maps s -> s // p. The previous ceil((s - k) / p) only agreed with
        # this when p == 2, so other pool sizes crashed at the first linear layer.
        size = 105
        for i in range(len(self.num_channels)):
            size = (size - self.kernel_sizes[i] + 1) // self.pool_sizes[i]
        linear_in = size * size * self.num_channels[-1]

        # convolution, batch-norm and pooling stages
        self.conv_layers = nn.ModuleList([nn.Conv2d(1, self.num_channels[0], self.kernel_sizes[0])])
        self.batch_norms = nn.ModuleList([nn.BatchNorm2d(self.num_channels[0])])
        self.pool_layers = nn.ModuleList([nn.MaxPool2d(self.pool_sizes[0])])
        for i in range(1, len(self.num_channels)):
            self.conv_layers.append(nn.Conv2d(self.num_channels[i-1], self.num_channels[i], self.kernel_sizes[i]))
            self.batch_norms.append(nn.BatchNorm2d(self.num_channels[i]))
            self.pool_layers.append(nn.MaxPool2d(self.pool_sizes[i]))
        self.dropout2d = nn.Dropout2d(p=self.conv_dropout)

        # fully connected embedding layers and the similarity head
        self.linear_layers = nn.ModuleList([nn.Linear(linear_in, self.emb_dims[0])])
        for i in range(1, len(self.emb_dims)):
            self.linear_layers.append(nn.Linear(self.emb_dims[i-1], self.emb_dims[i]))
        self.dropout = nn.Dropout(p=self.linear_dropout)
        self.similarity = nn.Linear(self.emb_dims[-1], 1)

    def embed_forward(self, image):
        """Embed a batch of images with the shared branch; returns (batch, emb_dims[-1])."""
        for conv, norm, pool in zip(self.conv_layers, self.batch_norms, self.pool_layers):
            image = self.dropout2d(image)
            image = conv(image)
            image = F.relu(image)
            image = norm(image)
            image = pool(image)

        image = torch.flatten(image, start_dim=1)

        for linear in self.linear_layers:
            image = self.dropout(image)
            image = linear(image)
            image = F.relu(image)

        return image

    def similarity_forward(self, embeds):
        """Score a pair of embedding batches; returns one probability per pair, shape (batch,)."""
        diff = torch.abs(embeds[0] - embeds[1])
        diff = self.dropout(diff)
        # squeeze only the size-1 score dim so a batch of one keeps shape (1,) for BCE
        return torch.sigmoid(self.similarity(diff).squeeze(-1))

    def forward(self, images):
        """Embed ``images = (images1, images2)`` and score them; returns (embeds, similarity)."""
        embeds = self.embed_forward(images[0]), self.embed_forward(images[1])
        similarity = self.similarity_forward(embeds)
        return embeds, similarity

    def forward_triplets(self, images):
        """Embed an (anchor, positive, negative) triplet and score anchor-positive and anchor-negative."""
        embed1, embed2, embed3 = self.embed_forward(images[0]), self.embed_forward(images[1]), self.embed_forward(images[2])
        similarity_same = self.similarity_forward((embed1, embed2))
        similarity_diff = self.similarity_forward((embed1, embed3))
        return (embed1, embed2, embed3), (similarity_same, similarity_diff)

    def compute_loss(self, images, targets):
        """Mean of the anchor-positive and anchor-negative BCE losses for a triplet batch.

        The target for each pair is 1.0 where its two labels are equal, else 0.0.
        Returns (embeds, loss).
        """
        embeds, similarity = self.forward_triplets(images)
        target_similarity_same = (targets[0] == targets[1]).float()
        target_similarity_diff = (targets[0] == targets[2]).float()

        loss_same = F.binary_cross_entropy(similarity[0], target_similarity_same)
        loss_diff = F.binary_cross_entropy(similarity[1], target_similarity_diff)

        # Plain mean of the two terms. The old extra /2 (copied from the two-ordering concat
        # variant) only shrank the loss and its gradients.
        loss = (loss_same + loss_diff) / 2
        return embeds, loss

In [None]:
class SiameseNetConcat(SiameseNet):
    """Siamese net ("S-Concat") that scores similarity from concatenated embeddings.

    It reuses the SiameseNet embedding branch and ``similarity`` layer. The head is
    concat -> Dropout -> Linear (``combine``) -> ReLU -> Dropout -> Linear -> sigmoid.
    Concatenation is order dependent, so in training mode both orderings (e1, e2) and
    (e2, e1) are scored and returned as a pair; in eval mode only (e1, e2) is scored.
    """

    def __init__(self, num_channels, kernel_sizes, pool_sizes, emb_dims, conv_dropout, linear_dropout):

        super().__init__(num_channels, kernel_sizes, pool_sizes, emb_dims, conv_dropout, linear_dropout)

        # maps the 2*D concatenation back to D so the inherited `similarity` layer can be reused
        self.combine = nn.Linear(self.emb_dims[-1]*2, self.emb_dims[-1])

    def similarity_forward_permutation(self, concat):
        """Score one ordered concatenation of embeddings; returns shape (batch,)."""
        concat = self.dropout(concat)
        concat = self.combine(concat)
        concat = F.relu(concat)
        concat = self.dropout(concat)
        # squeeze only the size-1 score dim so a batch of one keeps shape (1,) for BCE
        return torch.sigmoid(self.similarity(concat)).squeeze(-1)

    def similarity_forward(self, embeds):
        """Training: (score(e1, e2), score(e2, e1)). Eval: score(e1, e2) only."""
        forward_score = self.similarity_forward_permutation(torch.cat((embeds[0], embeds[1]), 1))
        if not self.training:
            return forward_score
        backward_score = self.similarity_forward_permutation(torch.cat((embeds[1], embeds[0]), 1))
        return forward_score, backward_score

    @staticmethod
    def _mean_bce(similarities, target):
        """Average BCE over one score tensor or a tuple of per-ordering score tensors."""
        if torch.is_tensor(similarities):
            similarities = (similarities,)
        return sum(F.binary_cross_entropy(s, target) for s in similarities) / len(similarities)

    def compute_loss(self, images, targets):
        """Mean of the anchor-positive and anchor-negative BCE losses, each averaged over orderings.

        In training mode this equals the previous value; it now also works in eval mode,
        where ``similarity_forward`` returns a single tensor instead of a pair.
        Returns (embeds, loss).
        """
        embeds, (similarity_same, similarity_diff) = self.forward_triplets(images)
        target_similarity_same = (targets[0] == targets[1]).float()
        target_similarity_diff = (targets[0] == targets[2]).float()

        loss_same = self._mean_bce(similarity_same, target_similarity_same)
        loss_diff = self._mean_bce(similarity_diff, target_similarity_diff)

        loss = (loss_same + loss_diff) / 2
        return embeds, loss

## Defining S-multires

In [None]:
class ConvBlock(nn.Module):
    """Dropout2d -> Conv2d (no bias) -> BatchNorm2d -> ReLU (in place).

    The conv has no bias; the following BatchNorm2d supplies a learnable shift.

    Args:
        in_channels, out_channels: conv input/output channel counts.
        kernel_size: conv kernel size.
        stride: conv stride (default 1).
        padding: conv zero padding (default 1).
        conv_dropout: Dropout2d probability applied to the block input.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=1, conv_dropout=0):
        super().__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.kernel_size, self.padding = kernel_size, padding

        # registration order (dropout, conv, bn, relu) fixes the state_dict layout
        self.dropout = nn.Dropout2d(p=conv_dropout)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(self.dropout(x))))


class MultiResNet(nn.Module):
    """Siamese net ("S-Multires") built from parallel multi-resolution conv branches.

    Each embedding stage runs four ConvBlocks (kernels 9/7/5/3, stride 1, padding that
    preserves the spatial size) on the previous stage's output and concatenates their
    channels. A stage with C channels per branch therefore outputs 4*C channels.

    The similarity head concatenates two embeddings along channels, then applies
    Dropout2d -> 3x3 conv -> ReLU -> flatten -> Dropout -> Linear -> sigmoid. In training
    mode both orderings are scored; in eval mode only (e1, e2) is scored.

    Args:
        num_conv_layers: number of multi-resolution stages (8 channels per branch).
        conv_dropout: Dropout2d probability in every ConvBlock and in the head.
        linear_dropout: dropout probability before the head's linear layer.

    The head's Linear(512 * 103 * 103, 1) assumes 105x105 inputs.

    Example:
        multires = MultiResNet(num_conv_layers=2, conv_dropout=0.0, linear_dropout=0.0).eval()
        image1, image2 = torch.randn(1, 1, 105, 105), torch.randn(1, 1, 105, 105)
        embed1, embed2 = multires.embed_forward(image1), multires.embed_forward(image2)
        similarity = multires.similarity_forward((embed1, embed2))
    """

    def __init__(self, num_conv_layers, conv_dropout, linear_dropout):
        super().__init__()
        # Not used by this model; kept so code reading `model.emb_dims` keeps working.
        self.emb_dims = 1024

        num_channels = [8] * num_conv_layers  # output channels of every branch at every stage

        def make_block(in_channels, out_channels, kernel_size, padding):
            return ConvBlock(in_channels, out_channels, kernel_size=kernel_size, padding=padding,
                             conv_dropout=conv_dropout)

        # One ModuleList per kernel size; creation order matches the original for identical init.
        self.conv1 = nn.ModuleList([make_block(1, num_channels[0], 9, 4)])
        self.conv2 = nn.ModuleList([make_block(1, num_channels[0], 7, 3)])
        self.conv3 = nn.ModuleList([make_block(1, num_channels[0], 5, 2)])
        self.conv4 = nn.ModuleList([make_block(1, num_channels[0], 3, 1)])
        for i in range(1, len(num_channels)):
            in_channels = num_channels[i - 1] * 4  # the previous stage's four branches are concatenated
            self.conv1.append(make_block(in_channels, num_channels[i], 9, 4))
            self.conv2.append(make_block(in_channels, num_channels[i], 7, 3))
            self.conv3.append(make_block(in_channels, num_channels[i], 5, 2))
            self.conv4.append(make_block(in_channels, num_channels[i], 3, 1))

        # Head input: two embeddings concatenated on channels, each four branches wide.
        head_channels = 2 * 4 * num_channels[-1]
        self.similarity_layer = nn.Sequential(
            nn.Dropout2d(p=conv_dropout),
            nn.Conv2d(head_channels, 512, 3),
            nn.ReLU(),
            nn.Flatten(),
            nn.Dropout(p=linear_dropout),
            nn.Linear(512 * 103 * 103, 1))  # 105x105 input -> 103x103 after the unpadded 3x3 conv

    def embed_forward(self, image):
        """Run every multi-resolution stage; returns the last stage's concatenated feature map."""
        features = image
        for block9, block7, block5, block3 in zip(self.conv1, self.conv2, self.conv3, self.conv4):
            features = torch.cat([block9(features), block7(features), block5(features), block3(features)], dim=1)
        return features

    def _score(self, first, second):
        """Similarity probability for one ordering of the pair; shape (batch,)."""
        concat = torch.cat((first, second), 1)
        return torch.sigmoid(self.similarity_layer(concat)).squeeze(-1)

    def similarity_forward(self, embeds):
        """Training: (score(e1, e2), score(e2, e1)). Eval: score(e1, e2) only."""
        forward_score = self._score(embeds[0], embeds[1])
        if not self.training:
            return forward_score
        return forward_score, self._score(embeds[1], embeds[0])

    def forward(self, images):
        """Embed ``images = (images1, images2)`` and score them; returns (embeds, similarity)."""
        embeds = (self.embed_forward(images[0]), self.embed_forward(images[1]))
        return embeds, self.similarity_forward(embeds)

    def forward_triplets(self, images):
        """Embed an (anchor, positive, negative) triplet and score anchor-positive and anchor-negative."""
        embed1, embed2, embed3 = self.embed_forward(images[0]), self.embed_forward(images[1]), self.embed_forward(images[2])
        similarity_same = self.similarity_forward((embed1, embed2))
        similarity_diff = self.similarity_forward((embed1, embed3))
        return (embed1, embed2, embed3), (similarity_same, similarity_diff)

    @staticmethod
    def _mean_bce(similarities, target):
        """Average BCE over one score tensor or a tuple of per-ordering score tensors."""
        if torch.is_tensor(similarities):
            similarities = (similarities,)
        return sum(F.binary_cross_entropy(s, target) for s in similarities) / len(similarities)

    def compute_loss(self, images, targets):
        """Mean of the anchor-positive and anchor-negative BCE losses, each averaged over orderings.

        Works in both training mode (two orderings) and eval mode (one). Returns (embeds, loss).
        """
        embeds, (similarity_same, similarity_diff) = self.forward_triplets(images)
        target_similarity_same = (targets[0] == targets[1]).float()
        target_similarity_diff = (targets[0] == targets[2]).float()

        loss_same = self._mean_bce(similarity_same, target_similarity_same)
        loss_diff = self._mean_bce(similarity_diff, target_similarity_diff)
        return embeds, (loss_same + loss_diff) / 2

## Training and testing functions

In [None]:
def make_classifications(model, num_classes, num_trials, batch_size, labels):
    """Estimate N-way one-shot accuracy and loss of `model` on the given classes.

    Each trial draws ``num_classes`` distinct classes from ``labels`` and takes one
    exemplar image from each, plus a second (query) image of the last class. The query
    is scored against every exemplar, and the trial is correct when the highest score
    belongs to the exemplar of the query's own class.

    Args:
        model: called as ``model((exemplars, queries))`` and expected to return
            ``(embeds, similarities)`` with one similarity per image pair. It should be
            in eval mode, because the concat models return a single score tensor only then.
        num_classes: ways per trial; must not exceed ``len(labels)``.
        num_trials: total number of one-shot trials.
        batch_size: trials evaluated per forward pass.
        labels: candidate class labels; each needs at least two images.

    Returns:
        (accuracy, loss) as floats:
        - accuracy: the fraction of correct trials.
        - loss: the binary cross-entropy averaged over every (query, exemplar) pair of
          every trial. Previously the loss was reset every batch, so only the last
          batch counted, and it was scaled by an extra ``num_classes``.
    """
    if num_trials <= 0:
        return 0.0, 0.0

    def load_image(filepath_index):
        # Convert while the file is open, then close the handle immediately.
        with Image.open(filepath_index_to_filepath[filepath_index]) as image:
            return converter(image)

    def build_trial():
        # one exemplar per class; the last class also provides the query image
        trial_labels = np.random.permutation(labels)[:num_classes]
        exemplar_indices = [np.random.choice(label_to_filepath_indices[label]) for label in trial_labels[:-1]]
        final_exemplar_index, query_index = np.random.choice(
            label_to_filepath_indices[trial_labels[-1]], 2, replace=False)
        exemplar_indices.append(final_exemplar_index)
        exemplars = torch.stack([load_image(index) for index in exemplar_indices])
        query = load_image(query_index).expand(exemplars.size())  # repeat the query once per exemplar
        return exemplars, query

    total_correct, total_loss = 0, 0.0
    with torch.no_grad():
        for start in range(0, num_trials, batch_size):
            trials_in_batch = min(batch_size, num_trials - start)
            exemplar_list, query_list = zip(*(build_trial() for _ in range(trials_in_batch)))
            exemplar_tensor = torch.cat(exemplar_list).to(device)
            query_tensor = torch.cat(query_list).to(device)

            _, similarities = model((exemplar_tensor, query_tensor))
            similarities = similarities.view(trials_in_batch, num_classes)

            # the matching exemplar is always in the last column
            true_similarity = torch.zeros_like(similarities)
            true_similarity[:, -1] = 1

            total_correct += (similarities.argmax(dim=1) == num_classes - 1).sum().item()
            total_loss += F.binary_cross_entropy(similarities, true_similarity).item() * trials_in_batch

    return total_correct / num_trials, total_loss / num_trials

In [None]:
def train(model, data_loader, model_file, num_epochs=200, lr=1e-3, valid_params=(5,320,32), concat_lr_ratio=1,
          patience=30, clip=50.0):
    """Train `model` with AdamW, early-stopping on one-shot validation accuracy.

    Args:
        model: network exposing ``compute_loss(images, targets) -> (embeds, loss)``.
        data_loader: iterable of batches ``{"image_tensors": ..., "labels": ...}``
            that supports ``len``.
        model_file: path where the best model's ``state_dict`` is saved.
        num_epochs: maximum number of epochs.
        lr: base learning rate.
        valid_params: ``(num_classes, num_trials, batch_size)`` passed to
            ``make_classifications`` for validation and the final test.
        concat_lr_ratio: learning-rate multiplier for parameters whose name contains
            ``'combine'`` (the S-Concat combine layer). It was previously ignored,
            because the optimizer was rebuilt right after the param groups were set up.
        patience: stop after this many consecutive epochs without a new best
            validation accuracy.
        clip: maximum gradient norm used for clipping.

    Returns:
        A deep copy of the model from the epoch with the best validation accuracy. If no
        epoch scores above 0, the live ``model`` itself is returned.
    """
    (num_classes, num_trials, valid_batch_size) = valid_params

    model.eval()
    valid_acc, valid_loss = make_classifications(model, num_classes, num_trials, valid_batch_size, valid_labels)
    print(' %d-way, one-shot accuracy: %.1f%%. Loss: %f'%(num_classes, valid_acc*100, valid_loss))

    # Give the concat "combine" layer its own learning rate.
    concat_var_names = ['combine']
    concat_params, not_concat_params = [], []
    for name, param in model.named_parameters():
        is_concat = any(key in name for key in concat_var_names)
        (concat_params if is_concat else not_concat_params).append(param)
    param_groups = [{'params': not_concat_params}]
    if concat_params:  # models without a combine layer get a single group
        param_groups.append({'params': concat_params, 'lr': lr * concat_lr_ratio})
    optimizer = torch.optim.AdamW(param_groups, lr=lr)

    best_valid_acc = 0
    best_model = model
    iters_since_best = 0

    for epoch in tqdm.trange(num_epochs, desc="training", unit="epoch"):
        with tqdm.tqdm(data_loader, desc=f"epoch {epoch + 1}", unit="batch", total=len(data_loader), position=0, leave=True) as batch_iterator:
            model.train()
            total_loss = 0.0
            for i, batch_data in enumerate(batch_iterator, start=1):
                images, target = batch_data["image_tensors"], batch_data["labels"]
                optimizer.zero_grad()
                _, loss = model.compute_loss(images, target)
                loss.backward()

                # clip gradients before taking the step
                nn.utils.clip_grad_norm_(model.parameters(), clip)
                optimizer.step()

                total_loss += loss.item()
                batch_iterator.set_postfix(mean_loss=total_loss / i, current_loss=loss.item())

        # compute validation accuracy and loss
        model.eval()
        valid_acc, valid_loss = make_classifications(model, num_classes, num_trials, valid_batch_size, valid_labels)
        print(' %d-way, one-shot validation accuracy: %.1f%%. Validation loss: %f'%(num_classes, valid_acc*100, valid_loss))

        if valid_acc > best_valid_acc:
            best_valid_acc = valid_acc
            iters_since_best = 0
            best_model = copy.deepcopy(model)
        else:
            iters_since_best += 1

        if iters_since_best >= patience:
            break

    # Test the best model and save it.
    print('Final best validation accuracy: %.1f%%.'%(best_valid_acc*100))
    test_acc, test_loss = make_classifications(best_model, num_classes, num_trials, valid_batch_size, test_labels)
    print(' %d-way, one-shot test accuracy: %.1f%%. Test loss: %f'%(num_classes, test_acc*100, test_loss))
    torch.save(best_model.state_dict(), model_file)
    return best_model

## Training the S-Subtract and S-Concat models

In [None]:
torch.cuda.empty_cache()  # release cached GPU memory left over from earlier runs

# data params
SEED = 0
triplets = True
num_triplets = int(1e5)
num_same_pairs, num_diff_pairs = int(1e4), int(1e4)

# Seed both RNGs so the sampled training set, the DataLoader shuffling and the model
# initialisation in the next cell are reproducible across Restart & Run All.
np.random.seed(SEED)
torch.manual_seed(SEED)

# Sample the data: `images` holds the filepath-index tuples and `indices` the matching class labels.
if triplets:
    images, indices = sample_triplets(num_triplets)
else:
    images, indices = sample_pairs(num_same_pairs, num_diff_pairs)

# set up dataset and dataloader (images are loaded lazily in collate_fn)
dataset = Omniglot_Dataset(images, indices, device, triplets=triplets)
data_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

In [None]:
# model params
num_channels = [64, 128, 128, 256]
kernel_sizes = [7, 5, 3, 3]
pool_sizes = [2, 2, 2, 2]
emb_dims = [4096, 1024]
conv_dropout, linear_dropout = 0.2, 0.5

# model = SiameseNet(num_channels, kernel_sizes, pool_sizes, emb_dims, conv_dropout, linear_dropout).to(device)
model = SiameseNetConcat(num_channels, kernel_sizes, pool_sizes, emb_dims, conv_dropout, linear_dropout).to(device)

In [None]:
# print model details
# Summarise the run configuration before training.
if not triplets:
    print("Pairs:", num_same_pairs, num_diff_pairs)
else:
    print("Triplets:", num_triplets)
print('S-Concat' if isinstance(model, SiameseNetConcat) else 'S-Subtract')

# count trainable parameters
model_parameters = [p for p in model.parameters() if p.requires_grad]
params = sum(np.prod(p.size()) for p in model_parameters)
print(params, 'parameters')

# train; the model with the best validation accuracy is saved to siamese_model.pt and returned
best_model = train(model, data_loader, "siamese_model.pt", lr=1e-4)

Triplets: 100000
S-Concat
23727873 parameters
 5-way, one-shot accuracy: 20.3%. Loss: 0.350669


epoch 1: 100%|██████████| 3125/3125 [08:00<00:00,  6.50batch/s, current_loss=0.444, mean_loss=0.576]
training:   0%|          | 1/200 [08:02<26:39:19, 482.21s/epoch]

 5-way, one-shot validation accuracy: 71.6%. Validation loss: 0.379999


epoch 2: 100%|██████████| 3125/3125 [08:02<00:00,  6.48batch/s, current_loss=0.434, mean_loss=0.468]
training:   1%|          | 2/200 [16:05<26:34:20, 483.14s/epoch]

 5-way, one-shot validation accuracy: 82.2%. Validation loss: 0.186675


epoch 3: 100%|██████████| 3125/3125 [08:02<00:00,  6.48batch/s, current_loss=0.472, mean_loss=0.425]
training:   2%|▏         | 3/200 [24:09<26:27:29, 483.50s/epoch]

 5-way, one-shot validation accuracy: 88.1%. Validation loss: 0.324264


epoch 4: 100%|██████████| 3125/3125 [08:01<00:00,  6.48batch/s, current_loss=0.271, mean_loss=0.401]
training:   2%|▏         | 4/200 [32:13<26:19:26, 483.50s/epoch]

 5-way, one-shot validation accuracy: 89.4%. Validation loss: 0.306855


epoch 5: 100%|██████████| 3125/3125 [08:01<00:00,  6.50batch/s, current_loss=0.323, mean_loss=0.385]
training:   2%|▎         | 5/200 [40:16<26:10:19, 483.18s/epoch]

 5-way, one-shot validation accuracy: 87.2%. Validation loss: 0.268889


epoch 6: 100%|██████████| 3125/3125 [08:01<00:00,  6.50batch/s, current_loss=0.514, mean_loss=0.373]
training:   3%|▎         | 6/200 [48:18<26:01:42, 483.00s/epoch]

 5-way, one-shot validation accuracy: 89.1%. Validation loss: 0.320250


epoch 7: 100%|██████████| 3125/3125 [08:00<00:00,  6.50batch/s, current_loss=0.377, mean_loss=0.364]
training:   4%|▎         | 7/200 [56:21<25:52:59, 482.80s/epoch]

 5-way, one-shot validation accuracy: 85.0%. Validation loss: 0.335085


epoch 8: 100%|██████████| 3125/3125 [08:01<00:00,  6.49batch/s, current_loss=0.342, mean_loss=0.354]
training:   4%|▍         | 8/200 [1:04:23<25:44:51, 482.77s/epoch]

 5-way, one-shot validation accuracy: 87.8%. Validation loss: 0.194210


epoch 9: 100%|██████████| 3125/3125 [08:01<00:00,  6.50batch/s, current_loss=0.333, mean_loss=0.347]
training:   4%|▍         | 9/200 [1:12:26<25:36:45, 482.75s/epoch]

 5-way, one-shot validation accuracy: 92.2%. Validation loss: 0.331463


epoch 10: 100%|██████████| 3125/3125 [08:02<00:00,  6.47batch/s, current_loss=0.355, mean_loss=0.339]
training:   5%|▌         | 10/200 [1:20:30<25:30:13, 483.23s/epoch]

 5-way, one-shot validation accuracy: 90.3%. Validation loss: 0.241578


epoch 11: 100%|██████████| 3125/3125 [08:01<00:00,  6.49batch/s, current_loss=0.259, mean_loss=0.334]
training:   6%|▌         | 11/200 [1:28:33<25:21:53, 483.14s/epoch]

 5-way, one-shot validation accuracy: 90.9%. Validation loss: 0.261379


epoch 12: 100%|██████████| 3125/3125 [08:01<00:00,  6.49batch/s, current_loss=0.326, mean_loss=0.328]
training:   6%|▌         | 12/200 [1:36:36<25:13:56, 483.17s/epoch]

 5-way, one-shot validation accuracy: 94.1%. Validation loss: 0.279219


epoch 13: 100%|██████████| 3125/3125 [08:02<00:00,  6.48batch/s, current_loss=0.347, mean_loss=0.322]
training:   6%|▋         | 13/200 [1:44:40<25:06:36, 483.40s/epoch]

 5-way, one-shot validation accuracy: 92.8%. Validation loss: 0.225366


epoch 14: 100%|██████████| 3125/3125 [08:02<00:00,  6.48batch/s, current_loss=0.344, mean_loss=0.32]
training:   7%|▋         | 14/200 [1:52:44<24:59:10, 483.60s/epoch]

 5-way, one-shot validation accuracy: 91.9%. Validation loss: 0.321050


epoch 15: 100%|██████████| 3125/3125 [08:02<00:00,  6.48batch/s, current_loss=0.351, mean_loss=0.316]
training:   8%|▊         | 15/200 [2:00:48<24:51:29, 483.73s/epoch]

 5-way, one-shot validation accuracy: 96.9%. Validation loss: 0.222482


epoch 16: 100%|██████████| 3125/3125 [08:01<00:00,  6.49batch/s, current_loss=0.352, mean_loss=0.314]
training:   8%|▊         | 16/200 [2:08:52<24:43:00, 483.59s/epoch]

 5-way, one-shot validation accuracy: 95.6%. Validation loss: 0.221295


epoch 17: 100%|██████████| 3125/3125 [08:02<00:00,  6.48batch/s, current_loss=0.162, mean_loss=0.31]
training:   8%|▊         | 17/200 [2:16:55<24:34:55, 483.58s/epoch]

 5-way, one-shot validation accuracy: 95.3%. Validation loss: 0.312506


epoch 18: 100%|██████████| 3125/3125 [08:05<00:00,  6.43batch/s, current_loss=0.403, mean_loss=0.307]
training:   9%|▉         | 18/200 [2:25:03<24:30:14, 484.70s/epoch]

 5-way, one-shot validation accuracy: 94.7%. Validation loss: 0.340916


epoch 19: 100%|██████████| 3125/3125 [08:06<00:00,  6.43batch/s, current_loss=0.184, mean_loss=0.305]
training:  10%|▉         | 19/200 [2:33:10<24:24:49, 485.58s/epoch]

 5-way, one-shot validation accuracy: 96.2%. Validation loss: 0.158146


epoch 20: 100%|██████████| 3125/3125 [08:05<00:00,  6.43batch/s, current_loss=0.404, mean_loss=0.302]
training:  10%|█         | 20/200 [2:41:18<24:18:29, 486.17s/epoch]

 5-way, one-shot validation accuracy: 95.9%. Validation loss: 0.396991


epoch 21: 100%|██████████| 3125/3125 [08:04<00:00,  6.44batch/s, current_loss=0.392, mean_loss=0.301]
training:  10%|█         | 21/200 [2:49:24<24:10:44, 486.28s/epoch]

 5-way, one-shot validation accuracy: 96.6%. Validation loss: 0.180501


epoch 22: 100%|██████████| 3125/3125 [08:05<00:00,  6.44batch/s, current_loss=0.358, mean_loss=0.298]
training:  11%|█         | 22/200 [2:57:31<24:03:14, 486.49s/epoch]

 5-way, one-shot validation accuracy: 97.5%. Validation loss: 0.149447


epoch 23: 100%|██████████| 3125/3125 [08:06<00:00,  6.43batch/s, current_loss=0.236, mean_loss=0.298]
training:  12%|█▏        | 23/200 [3:05:39<23:56:10, 486.84s/epoch]

 5-way, one-shot validation accuracy: 92.8%. Validation loss: 0.190563


epoch 24: 100%|██████████| 3125/3125 [08:06<00:00,  6.42batch/s, current_loss=0.227, mean_loss=0.297]
training:  12%|█▏        | 24/200 [3:13:47<23:49:21, 487.28s/epoch]

 5-way, one-shot validation accuracy: 95.9%. Validation loss: 0.392059


epoch 25: 100%|██████████| 3125/3125 [08:06<00:00,  6.43batch/s, current_loss=0.422, mean_loss=0.294]
training:  12%|█▎        | 25/200 [3:21:55<23:41:49, 487.48s/epoch]

 5-way, one-shot validation accuracy: 98.4%. Validation loss: 0.134893


epoch 26: 100%|██████████| 3125/3125 [08:07<00:00,  6.41batch/s, current_loss=0.359, mean_loss=0.291]
training:  13%|█▎        | 26/200 [3:30:04<23:35:15, 488.02s/epoch]

 5-way, one-shot validation accuracy: 98.4%. Validation loss: 0.132451


epoch 27: 100%|██████████| 3125/3125 [08:07<00:00,  6.41batch/s, current_loss=0.333, mean_loss=0.291]
training:  14%|█▎        | 27/200 [3:38:13<23:27:58, 488.31s/epoch]

 5-way, one-shot validation accuracy: 99.1%. Validation loss: 0.158971


epoch 28: 100%|██████████| 3125/3125 [08:06<00:00,  6.42batch/s, current_loss=0.202, mean_loss=0.289]
training:  14%|█▍        | 28/200 [3:46:22<23:19:57, 488.36s/epoch]

 5-way, one-shot validation accuracy: 98.1%. Validation loss: 0.169741


epoch 29: 100%|██████████| 3125/3125 [08:06<00:00,  6.42batch/s, current_loss=0.346, mean_loss=0.291]
training:  14%|█▍        | 29/200 [3:54:30<23:11:56, 488.40s/epoch]

 5-way, one-shot validation accuracy: 97.2%. Validation loss: 0.174760


epoch 30: 100%|██████████| 3125/3125 [08:06<00:00,  6.43batch/s, current_loss=0.341, mean_loss=0.287]
training:  15%|█▌        | 30/200 [4:02:38<23:03:24, 488.26s/epoch]

 5-way, one-shot validation accuracy: 96.9%. Validation loss: 0.133090


epoch 31: 100%|██████████| 3125/3125 [08:07<00:00,  6.41batch/s, current_loss=0.231, mean_loss=0.286]
training:  16%|█▌        | 31/200 [4:10:47<22:55:38, 488.39s/epoch]

 5-way, one-shot validation accuracy: 95.6%. Validation loss: 0.224862


epoch 32: 100%|██████████| 3125/3125 [08:06<00:00,  6.42batch/s, current_loss=0.282, mean_loss=0.284]
training:  16%|█▌        | 32/200 [4:18:55<22:47:21, 488.34s/epoch]

 5-way, one-shot validation accuracy: 96.6%. Validation loss: 0.098161


epoch 33: 100%|██████████| 3125/3125 [08:06<00:00,  6.42batch/s, current_loss=0.245, mean_loss=0.287]
training:  16%|█▋        | 33/200 [4:27:03<22:38:58, 488.25s/epoch]

 5-way, one-shot validation accuracy: 93.4%. Validation loss: 0.198669


epoch 34: 100%|██████████| 3125/3125 [08:06<00:00,  6.42batch/s, current_loss=0.253, mean_loss=0.285]
training:  17%|█▋        | 34/200 [4:35:12<22:31:01, 488.32s/epoch]

 5-way, one-shot validation accuracy: 97.2%. Validation loss: 0.094090


epoch 35: 100%|██████████| 3125/3125 [08:06<00:00,  6.43batch/s, current_loss=0.306, mean_loss=0.282]
training:  18%|█▊        | 35/200 [4:43:20<22:22:32, 488.20s/epoch]

 5-way, one-shot validation accuracy: 97.2%. Validation loss: 0.120643


epoch 36: 100%|██████████| 3125/3125 [08:06<00:00,  6.42batch/s, current_loss=0.272, mean_loss=0.282]
training:  18%|█▊        | 36/200 [4:51:28<22:14:35, 488.27s/epoch]

 5-way, one-shot validation accuracy: 96.6%. Validation loss: 0.164893


epoch 37: 100%|██████████| 3125/3125 [08:07<00:00,  6.41batch/s, current_loss=0.29, mean_loss=0.281]
training:  18%|█▊        | 37/200 [4:59:37<22:07:01, 488.47s/epoch]

 5-way, one-shot validation accuracy: 97.8%. Validation loss: 0.195291


epoch 38: 100%|██████████| 3125/3125 [08:07<00:00,  6.42batch/s, current_loss=0.201, mean_loss=0.281]
training:  19%|█▉        | 38/200 [5:07:46<21:58:59, 488.52s/epoch]

 5-way, one-shot validation accuracy: 98.1%. Validation loss: 0.099972


epoch 39: 100%|██████████| 3125/3125 [08:07<00:00,  6.41batch/s, current_loss=0.305, mean_loss=0.281]
training:  20%|█▉        | 39/200 [5:15:55<21:51:08, 488.62s/epoch]

 5-way, one-shot validation accuracy: 95.3%. Validation loss: 0.312684


epoch 40: 100%|██████████| 3125/3125 [08:06<00:00,  6.42batch/s, current_loss=0.314, mean_loss=0.279]
training:  20%|██        | 40/200 [5:24:03<21:42:33, 488.46s/epoch]

 5-way, one-shot validation accuracy: 95.3%. Validation loss: 0.150551


epoch 41: 100%|██████████| 3125/3125 [08:06<00:00,  6.42batch/s, current_loss=0.249, mean_loss=0.278]
training:  20%|██        | 41/200 [5:32:11<21:34:26, 488.47s/epoch]

 5-way, one-shot validation accuracy: 96.6%. Validation loss: 0.139179


epoch 42: 100%|██████████| 3125/3125 [08:06<00:00,  6.42batch/s, current_loss=0.291, mean_loss=0.278]
training:  21%|██        | 42/200 [5:40:19<21:26:08, 488.41s/epoch]

 5-way, one-shot validation accuracy: 95.9%. Validation loss: 0.179681


epoch 43: 100%|██████████| 3125/3125 [08:06<00:00,  6.42batch/s, current_loss=0.304, mean_loss=0.278]
training:  22%|██▏       | 43/200 [5:48:27<21:17:39, 488.28s/epoch]

 5-way, one-shot validation accuracy: 97.2%. Validation loss: 0.156739


epoch 44: 100%|██████████| 3125/3125 [08:05<00:00,  6.43batch/s, current_loss=0.359, mean_loss=0.278]
training:  22%|██▏       | 44/200 [5:56:35<21:08:55, 488.05s/epoch]

 5-way, one-shot validation accuracy: 96.9%. Validation loss: 0.129933


epoch 45: 100%|██████████| 3125/3125 [08:05<00:00,  6.43batch/s, current_loss=0.186, mean_loss=0.276]
training:  22%|██▎       | 45/200 [6:04:42<21:00:21, 487.88s/epoch]

 5-way, one-shot validation accuracy: 97.5%. Validation loss: 0.165898


epoch 46: 100%|██████████| 3125/3125 [08:05<00:00,  6.43batch/s, current_loss=0.214, mean_loss=0.277]
training:  23%|██▎       | 46/200 [6:12:50<20:51:52, 487.74s/epoch]

 5-way, one-shot validation accuracy: 98.8%. Validation loss: 0.075916


epoch 47: 100%|██████████| 3125/3125 [08:05<00:00,  6.44batch/s, current_loss=0.288, mean_loss=0.278]
training:  24%|██▎       | 47/200 [6:20:57<20:43:15, 487.55s/epoch]

 5-way, one-shot validation accuracy: 97.5%. Validation loss: 0.106281


epoch 48: 100%|██████████| 3125/3125 [08:06<00:00,  6.42batch/s, current_loss=0.245, mean_loss=0.276]
training:  24%|██▍       | 48/200 [6:29:05<20:35:29, 487.69s/epoch]

 5-way, one-shot validation accuracy: 96.6%. Validation loss: 0.072471


epoch 49: 100%|██████████| 3125/3125 [08:06<00:00,  6.43batch/s, current_loss=0.224, mean_loss=0.274]
training:  24%|██▍       | 49/200 [6:37:13<20:27:27, 487.73s/epoch]

 5-way, one-shot validation accuracy: 97.5%. Validation loss: 0.137844


epoch 50: 100%|██████████| 3125/3125 [08:05<00:00,  6.44batch/s, current_loss=0.263, mean_loss=0.274]
training:  25%|██▌       | 50/200 [6:45:20<20:18:45, 487.50s/epoch]

 5-way, one-shot validation accuracy: 92.8%. Validation loss: 0.192448


epoch 51: 100%|██████████| 3125/3125 [08:05<00:00,  6.43batch/s, current_loss=0.235, mean_loss=0.274]
training:  26%|██▌       | 51/200 [6:53:27<20:10:26, 487.42s/epoch]

 5-way, one-shot validation accuracy: 97.8%. Validation loss: 0.098805


epoch 52: 100%|██████████| 3125/3125 [08:06<00:00,  6.43batch/s, current_loss=0.3, mean_loss=0.273]
training:  26%|██▌       | 52/200 [7:01:35<20:02:26, 487.47s/epoch]

 5-way, one-shot validation accuracy: 97.5%. Validation loss: 0.065325


epoch 53: 100%|██████████| 3125/3125 [08:05<00:00,  6.44batch/s, current_loss=0.326, mean_loss=0.273]
training:  26%|██▋       | 53/200 [7:09:42<19:54:03, 487.37s/epoch]

 5-way, one-shot validation accuracy: 97.8%. Validation loss: 0.092698


epoch 54: 100%|██████████| 3125/3125 [08:06<00:00,  6.42batch/s, current_loss=0.236, mean_loss=0.273]
training:  27%|██▋       | 54/200 [7:17:50<19:46:22, 487.55s/epoch]

 5-way, one-shot validation accuracy: 98.1%. Validation loss: 0.116222


epoch 55: 100%|██████████| 3125/3125 [08:05<00:00,  6.44batch/s, current_loss=0.211, mean_loss=0.274]
training:  28%|██▊       | 55/200 [7:25:57<19:37:45, 487.35s/epoch]

 5-way, one-shot validation accuracy: 96.6%. Validation loss: 0.159826


epoch 56: 100%|██████████| 3125/3125 [08:05<00:00,  6.43batch/s, current_loss=0.262, mean_loss=0.273]
training:  28%|██▊       | 56/200 [7:34:04<19:29:33, 487.32s/epoch]

 5-way, one-shot validation accuracy: 95.3%. Validation loss: 0.114390


epoch 57: 100%|██████████| 3125/3125 [08:06<00:00,  6.43batch/s, current_loss=0.283, mean_loss=0.272]
training:  28%|██▊       | 56/200 [7:42:11<19:48:30, 495.21s/epoch]

 5-way, one-shot validation accuracy: 96.6%. Validation loss: 0.078911
Final best validation accuracy: 99.1%.





 5-way, one-shot test accuracy: 90.3%. Test loss: 0.310094


In [None]:
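# Optional: download the saved checkpoint from the Colab runtime. `files` is
# re-imported because the os.walk loop earlier in the notebook rebinds that name.
# from google.colab import files
# files.download('siamese_model.pt')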

## S-Multires training

In [None]:
torch.cuda.empty_cache()  # release cached GPU memory before building the data

# Data parameters: choose between sampling triplets or same/different pairs.
triplets = True
num_triplets = int(1e4)
num_same_pairs, num_diff_pairs = int(1e4), int(1e4)

# Store the samples under their own name so the boolean `triplets` flag is
# not overwritten. The flag must reach Omniglot_Dataset as `triplets=True/False`
# and must stay a bool if this cell is re-run.
if triplets:
    images, indices = sample_triplets(num_triplets)
else:
    images, indices = sample_pairs(num_same_pairs, num_diff_pairs)

# Set up the dataset; the DataLoader is built in the next cell.
dataset = Omniglot_Dataset(images, indices, device, triplets=triplets)

In [None]:
# Training hyper-parameters.
max_num_epochs = 200
batch_size = 64

# MultiResNet hyper-parameters.
num_conv_layers = 8
conv_dropout = 0.2
linear_dropout = 0.5

# Few-shot evaluation settings. They are not passed explicitly below; they are
# kept as globals because they may be read elsewhere (TODO confirm against `train`).
num_classes = 5
num_tests = 320

# Shuffling loader over the dataset built in the previous cell, plus the model.
data_loader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
)
multires = MultiResNet(num_conv_layers, conv_dropout, linear_dropout)
multires = multires.to(device)

# Count only the parameters that require gradients (i.e. are trainable).
model_parameters = (p for p in multires.parameters() if p.requires_grad)
params = sum(np.prod(p.size()) for p in model_parameters)
print(params, 'parameters')

6022945 parameters


In [None]:
# Train the multi-resolution network. `train` saves the best model's weights
# (presumably to the file path given as the third argument) and returns that model.
best_model = train(
    multires,
    data_loader,
    "siamese_model_multires.pt",
    lr=1e-5,
)