<a href="https://colab.research.google.com/github/wileyw/DeepLearningDemos/blob/master/TwinNetwork/twin_network.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Twin Network

In [None]:
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import os

import torch
import torch.nn as nn
import torch.nn.functional as F


In [None]:
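# One-time setup (left commented out): extract the fruits-360 archive into Google Drive.
# !unzip drive/MyDrive/fruits-360.zip -d drive/MyDrive
# Background reading on siamese/twin networks:
# https://towardsdatascience.com/siamese-networks-line-by-line-explanation-for-beginners-55b8be1d2fc6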

## Data Preprocessing

In [None]:
base_dir = r'/content/drive/MyDrive/fruits-360/Training/'
train_test_split = 0.7
no_of_files_in_each_class = 10

# Read all the class folders in the directory.
# os.listdir returns entries in arbitrary order, so sort them to keep the
# class -> label mapping (and therefore the later train/test split) stable
# between runs. Plain files next to the class folders are ignored because
# listing them as directories would raise.
folder_list = sorted(
    name for name in os.listdir(base_dir)
    if os.path.isdir(os.path.join(base_dir, name))
)
print( len(folder_list), "categories found in the dataset")

# Declare training arrays:
#   x        - list of images, each an (100, 100, 3) RGB array
#   y        - integer class label of each image in x
#   cat_list - one row per kept class, holding the indices into x of its images
cat_list = []
x = []
y = []
y_label = 0

# Using just no_of_files_in_each_class images per category
for folder_name in folder_list:
    class_dir = os.path.join(base_dir, folder_name)
    files_list = sorted(os.listdir(class_dir))
    if len(files_list) < no_of_files_in_each_class:
        # Every row of cat_list must have the same length, so classes with
        # too few images are dropped entirely.
        print(f"skipping {folder_name}")
        continue
    temp = []
    for file_name in files_list[:no_of_files_in_each_class]:
        temp.append(len(x))
        # Use a context manager so the underlying file handle is closed
        # as soon as the pixels have been copied into the array.
        with Image.open(os.path.join(class_dir, file_name)) as img:
            x.append(np.asarray(img.convert('RGB').resize((100, 100))))
        y.append(y_label)
    y_label += 1
    cat_list.append(temp)

cat_list = np.asarray(cat_list)
# Scale pixel values from [0, 255] to [0, 1].
x = np.asarray(x)/255.0
y = np.asarray(y)
print('X, Y shape',x.shape, y.shape, cat_list.shape)

In [None]:
# Adapt x input dimension to PyTorch format: (N, H, W, C) -> (N, C, H, W).
# The check on the last axis makes the cell safe to re-run: once the channel
# axis has been moved, a second execution leaves x untouched instead of
# permuting the axes again.
if x.shape[-1] == 3:
    x = x.transpose(0, 3, 1, 2)
print('X, Y shape',x.shape, y.shape, cat_list.shape)

## Train Test Split

In [None]:
# Split by class, not by image: the first train_size classes are used for
# training and the remaining ones are held out.
# The class count is taken from cat_list rather than folder_list because
# folders with too few images were skipped while loading and have no row in
# cat_list / no images in x.
n_classes = len(cat_list)
train_size = int(n_classes*train_test_split)
test_size = n_classes - train_size
print(train_size, 'classes for training and', test_size, ' classes for testing')

# Images were appended class by class, so the first train_files images belong
# to the training classes.
train_files = train_size * no_of_files_in_each_class

#Training Split
x_train = x[:train_files]
y_train = y[:train_files]
cat_train = cat_list[:train_size]

#Validation Split
# NOTE: cat_test keeps indices into the full array x (not into x_val).
x_val = x[train_files:]
y_val = y[train_files:]
cat_test = cat_list[train_size:]

print('X&Y shape of training data :',x_train.shape, 'and', y_train.shape, cat_train.shape)
print('X&Y shape of testing data :' , x_val.shape, 'and', y_val.shape, cat_test.shape)

## Generating Batch

In [None]:
def get_batch(batch_size=64):
    """Sample a batch of image pairs from the training classes.

    Reads the module-level arrays ``x_train`` (images, channels first) and
    ``cat_train`` (one row per training class holding the indices of that
    class's images in ``x_train``).

    Args:
        batch_size: number of pairs to generate.

    Returns:
        A tuple ``(batch_x, batch_y)`` where ``batch_x`` is a list of two
        arrays of shape ``(batch_size, *image_shape)`` (left and right image
        of each pair) and ``batch_y`` is a float array of shape
        ``(batch_size,)`` with 0 for a same-class pair and 1 for a pair from
        two different classes.

    Raises:
        ValueError: if fewer than two training classes are available, since
            different-class pairs could not be formed.
    """
    images = x_train
    class_to_indices = cat_train
    # Derive the class count from the table that is actually indexed below,
    # so a stale or mismatched global train_size can never index out of range.
    n_classes = len(class_to_indices)
    if n_classes < 2:
        raise ValueError("get_batch needs at least two training classes")

    # Second half of the labels is set to 1 (different class), then shuffled.
    batch_y = np.zeros(batch_size)
    batch_y[batch_size // 2:] = 1
    np.random.shuffle(batch_y)

    # One anchor class per pair.
    anchor_classes = np.random.randint(0, n_classes, batch_size)

    # Image shape is taken from the data instead of being hard-coded.
    pair_shape = (batch_size,) + tuple(images.shape[1:])
    batch_x = [np.zeros(pair_shape), np.zeros(pair_shape)]

    for i, anchor_class in enumerate(anchor_classes):
        batch_x[0][i] = images[np.random.choice(class_to_indices[anchor_class])]
        # If batch_y has 0 pick from the same class, else pick from any other class
        if batch_y[i] == 0:
            batch_x[1][i] = images[np.random.choice(class_to_indices[anchor_class])]
        else:
            other_indices = np.append(
                class_to_indices[:anchor_class].flatten(),
                class_to_indices[anchor_class + 1:].flatten(),
            )
            batch_x[1][i] = images[np.random.choice(other_indices)]

    return (batch_x, batch_y)

## Twin Network

In [None]:
# Convolutional embedding network used by both branches of the twin network
class CnnNetwork(nn.Module):
    """Map an RGB image batch to 4096-dimensional embeddings with values in (0, 1).

    The stack is four unpadded convolutions, each followed by ReLU and 2x2
    max pooling, then one fully connected layer with a sigmoid.

    ``fc1`` expects 256 * 4 = 1024 flattened features. For a 100x100 input the
    spatial size shrinks 100 -> 91 -> 45 -> 39 -> 19 -> 16 -> 8 -> 5 -> 2, so
    the last feature map is 256 x 2 x 2, which matches. Other input sizes may
    not fit ``fc1``.
    """

    def __init__(self):
        super().__init__()
        # 3 input channels (RGB); square kernels of size 10, 7, 4 and 4.
        self.conv1 = nn.Conv2d(3, 64, 10)
        self.conv2 = nn.Conv2d(64, 128, 7)
        self.conv3 = nn.Conv2d(128, 128, 4)
        self.conv4 = nn.Conv2d(128, 256, 4)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(256 * 4, 4096)

    def forward(self, x):
        """Return the embedding of ``x`` (shape ``(N, 3, H, W)``) as ``(N, 4096)``."""
        # Max pooling over a (2, 2) window after every convolution.
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = F.max_pool2d(F.relu(self.conv3(x)), 2)
        x = F.max_pool2d(F.relu(self.conv4(x)), 2)
        # Collapse everything except the batch dimension.
        x = torch.flatten(x, 1)
        x = torch.sigmoid(self.fc1(x))
        return x

    def num_flat_features(self, x):
        """Return the number of values per sample in ``x`` (product of all non-batch dims).

        No longer used by ``forward``; kept so existing callers keep working.
        """
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features


class TwinNetwork(nn.Module):
    """Twin (siamese) network scoring a pair of images with one shared CNN.

    Both inputs go through the same ``CnnNetwork`` instance, so the two
    branches share weights. The element-wise absolute difference of the two
    embeddings is fed to a single linear unit followed by a sigmoid.
    """

    def __init__(self):
        super().__init__()
        # Single embedding network reused for the left and the right input.
        self.cnn = CnnNetwork()
        # Maps the 4096-dim embedding difference to one score.
        self.fc1 = nn.Linear(4096, 1)

    def forward(self, left, right):
        """Return one score in (0, 1) per pair, shape ``(N, 1)``."""
        embedding_left = self.cnn(left)
        embedding_right = self.cnn(right)
        distance = (embedding_left - embedding_right).abs()
        return torch.sigmoid(self.fc1(distance))


## N-way one-shot Learning

In [None]:
def nway_one_shot(n_way, n_val):
    """Estimate N-way one-shot accuracy of ``twin_net`` on the held-out classes.

    Reads the module-level names ``x`` (all images), ``cat_list`` (one row of
    image indices per class), ``train_size`` (number of training classes) and
    ``twin_net`` (the model being evaluated).

    For each of ``n_val`` trials a query image is drawn from a held-out class
    and paired with ``n_way`` candidates: one more image from the same class
    and ``n_way - 1`` images from any other class. The trial counts as correct
    when the same-class candidate gets the lowest score from ``twin_net``.

    Args:
        n_way: number of candidates per trial (must be >= 1).
        n_val: number of trials.

    Returns:
        Accuracy in percent (float).

    Raises:
        ValueError: if ``n_way`` is smaller than 1.
    """
    if n_way < 1:
        raise ValueError("n_way must be at least 1")

    n_classes = len(cat_list)
    # Held-out classes are train_size .. n_classes - 1. np.random.randint
    # excludes the upper bound, so passing n_classes keeps the last class
    # eligible (and train_size itself is the first held-out class).
    class_list = np.random.randint(train_size, n_classes, n_val)

    # Slot that holds the same-class candidate. 2 is arbitrary; 0 is avoided
    # because it is the index returned when all scores are equal, which
    # would lead to wrong conclusions. Falls back to the last slot when
    # n_way is too small to have a slot 2.
    match_slot = min(2, n_way - 1)

    # Run on whatever device the model lives on instead of assuming CUDA.
    first_param = next(twin_net.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    image_shape = tuple(x.shape[1:])
    n_correct = 0

    for i in class_list:
        query = x[np.random.choice(cat_list[i])]
        # Indices of every image that does not belong to class i.
        distractor_pool = np.append(cat_list[:i].flatten(), cat_list[i+1:].flatten())

        left = np.zeros((n_way,) + image_shape)
        right = np.zeros((n_way,) + image_shape)
        for k in range(0, n_way):
            left[k] = query
            if k == match_slot:
                right[k] = x[np.random.choice(cat_list[i])]
            else:
                right[k] = x[np.random.choice(distractor_pool)]

        # Evaluation only: no gradients are needed.
        with torch.no_grad():
            result = twin_net(torch.Tensor(left).to(device), torch.Tensor(right).to(device))
        result = result.flatten().tolist()
        if result.index(min(result)) == match_slot:
            n_correct = n_correct + 1

    print(n_correct, "correctly classified among", n_val)
    accuracy = (n_correct*100)/n_val
    return accuracy

In [None]:
# Tools to display batch data graphically
def display_batch(batch_x, batch_y, batch_size=64):
  """Show the left and right images of a batch as two square image grids.

  Args:
    batch_x: pair ``[left, right]`` of arrays shaped ``(N, C, H, W)``.
    batch_y: array of ``N`` pair labels.
    batch_size: number of pairs to consider. Only the first ``num * num``
      pairs are drawn, where ``num = int(sqrt(batch_size))``; the printed
      labels are restricted to those same pairs so the two always line up.

  Raises:
    ValueError: if there is no pair to display.
  """
  # Never read past the data that was actually passed in.
  batch_size = min(batch_size, len(batch_x[0]), len(batch_x[1]), len(batch_y))
  num = int(batch_size ** 0.5)
  if num == 0:
    raise ValueError("display_batch needs at least one image pair")
  shown = num * num

  # Tile size comes from the data instead of a hard-coded 100x100x3.
  channels, height, width = batch_x[0].shape[1:]
  combined_left = np.zeros((num*height, num*width, channels))
  combined_right = np.zeros((num*height, num*width, channels))

  for count in range(shown):
    i, j = divmod(count, num)
    rows = slice(i*height, (i+1)*height)
    cols = slice(j*width, (j+1)*width)
    # Back to (H, W, C) for plotting.
    combined_left[rows, cols, :] = batch_x[0][count].transpose(1, 2, 0)
    combined_right[rows, cols, :] = batch_x[1][count].transpose(1, 2, 0)

  plt.imshow(combined_left)
  plt.show()
  plt.imshow(combined_right)
  plt.show()
  print("batch_y is")
  # Same row-major layout as the image grids above.
  print(np.reshape(batch_y[:shown], (num, num)))

## Training the Model

In [None]:
# Seed NumPy and PyTorch so that batch sampling, weight initialisation and the
# one-shot evaluation are repeatable from a fresh kernel.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

# We started learning_rate at 0.0006, but it was too coarse.
learning_rate = 0.0001
# Use the GPU when there is one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
twin_net = TwinNetwork().to(device)
optimizer = torch.optim.Adam(twin_net.parameters(), lr=learning_rate)

# Binary cross-entropy between the network score and the pair label
# (the labels come from get_batch).
loss = nn.BCELoss()

# NOTE: each "epoch" below is a single optimisation step on one random batch,
# not a full pass over the training data.
epochs = 30000
n_way = 20
n_val = 100
batch_size = 64
eval_every = 250  # evaluate and report every this many steps

loss_list=[]
accuracy_list=[]

# Baseline accuracy of the untrained network (printed only, not recorded).
accuracy = nway_one_shot(n_way, n_val)
print('Accuracy as of', 0, 'epochs:', accuracy)

for epoch in range(epochs):
    batch_x, batch_y = get_batch(batch_size)
    # display_batch(batch_x, batch_y)
    # zero the parameter gradients
    optimizer.zero_grad()

    # forward + backward + optimize
    left = torch.Tensor(batch_x[0]).to(device)
    right = torch.Tensor(batch_x[1]).to(device)
    # reshape(-1, 1) follows the real batch size instead of a hard-coded 64,
    # so changing batch_size above does not break the loss computation.
    targets = torch.Tensor(batch_y).reshape(-1, 1).to(device)
    twin_outputs = twin_net(left, right)
    outputs = loss(twin_outputs, targets)
    outputs.backward()
    optimizer.step()

    # print('Epoch:', epoch, ', Loss:',outputs)
    loss_list.append(outputs.item())
    # print statistics
    if epoch % eval_every == 0:
        print("=============================================")
        accuracy = nway_one_shot(n_way, n_val)
        accuracy_list.append((epoch, accuracy))
        print('Accuracy as of', epoch, 'epochs:', accuracy)
        # Mean loss over the steps since the previous report.
        print('Epoch:', epoch, ', Loss:',np.mean(loss_list[-eval_every:]))
        print("=============================================")
        # Early stopping once the one-shot accuracy target is reached.
        if(accuracy>90):
            print("Achieved more than 90% Accuracy")
            break