In [4]:
from fastai.vision.all import *
from kornia import rgb_to_grayscale
import torch.nn as nn
import warnings
import random, textwrap
import os, glob
import numpy as np
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import skimage
from skimage.util import random_noise
from skimage import util
from skimage import exposure
from scipy import ndimage
from datetime import datetime
import torch
import torch.utils.model_zoo
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
from pathlib import Path
import collections
import seaborn as sn

warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
defaults.device = torch.device('cuda')

seed = 30
np.random.seed(seed)
torch.manual_seed(seed)

sn.set(style="darkgrid")

# stim_path = Path(r'D:\Andrea_NN\stimuli\no_transf')
stim_path = Path(r'D:\Andrea_NN\stimuli\samediff')

epochs = 1000
cycles = 1
batch_sz = 24
lr_min = 1e-3
weight_decay = 1e-3
w_dropout_1 = 0.8
w_dropout_2 = 0.8
fov_noise = False


class SiameseNetEncoderFB(nn.Module):
    def __init__(self):
        super(SiameseNetEncoderFB, self).__init__()

        # V1 layers
        self.V1_p = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=7 // 2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        # V2 layers
        self.V2_p = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=3 // 2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        # V4 layers
        self.V4_p = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=3 // 2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        # IT layers
        self.IT_p = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=3 // 2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        # V1 layers
        self.V1_f = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2,
                      padding=7 // 2),  # + self.vfb,
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        # V2 layers
        self.V2_f = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=3 // 2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        # V4 layers
        self.V4_f = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=3 // 2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        # IT layers
        self.IT_f = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=3 // 2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        # head
        self.head = nn.Sequential(
            AdaptiveConcatPool2d(),
            nn.Flatten(),
            nn.BatchNorm1d(1024 * 2,
                           eps=1e-05,
                           momentum=0.1,
                           affine=True,
                           track_running_stats=True),
            nn.Dropout(p=w_dropout_1, inplace=False),
            nn.Linear(in_features=1024 * 2, out_features=512, bias=False),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(512,
                           eps=1e-05,
                           momentum=0.1,
                           affine=True,
                           track_running_stats=True),
            nn.Dropout(p=w_dropout_2, inplace=False),
            nn.Linear(in_features=512, out_features=2, bias=False),
        )

        self.fb = nn.Sequential(
            nn.Conv2d(1024, 3, kernel_size=3, stride=1, padding=221),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

    def forward(self, inp):
        inp1 = inp[0]
        inp2 = inp[1]
        fov_inp = inp[2]

        # perihperal 1
        v1_p1 = self.V1_p(inp1)
        v2_p1 = self.V2_p(v1_p1)
        v4_p1 = self.V4_p(v2_p1)
        vIT_p1 = self.IT_p(v4_p1)

        # perihperal 1
        v1_p2 = self.V1_p(inp2)
        v2_p2 = self.V2_p(v1_p2)
        v4_p2 = self.V4_p(v2_p2)
        vIT_p2 = self.IT_p(v4_p2)

        out_cat = torch.cat((vIT_p1, vIT_p2), 1)

#         fb = self.fb(out_cat)
#         try:
#             v1_fov = self.V1_f(fb + fov_inp)
#         except:
#             v1_fov = self.V1_f(fov_inp)

#         v2_fov = self.V2_f(v1_fov)
#         v4_fov = self.V4_f(v2_fov)
#         vIT_fov = self.IT_f(v4_fov)

#         out_all = torch.cat((vIT_p1 + vIT_fov, vIT_p2 + vIT_fov), 1)
#         out_all = torch.cat((vIT_p1, vIT_p2, vIT_fov), 1)
        out_all = torch.cat((vIT_p1, vIT_p2), 1)
    
        out = self.head(out_all)

        return out
    

def label_func(path):
    split_name = path.stem.split("_")
    return 0 if split_name[-1] == split_name[-2] else 1


def get_img_tuple_no_noise(path):
    pair = Image.open(path)

    label = label_func(Path(path))
    orientation = os.path.basename(path).split("_")[-3]

    width, height = pair.size

    if orientation == "normal":
        left1, top1, right1, bottom1 = width - width // 4, 0, width, height // 4
        left2, top2, right2, bottom2 = 0, height - height // 4, width // 4, height
    else:
        left1, top1, right1, bottom1 = 0, 0, width // 4, height // 4
        left2, top2, right2, bottom2 = (
            width - width // 4,
            height - height // 4,
            width,
            height,
        )

    im1 = pair.crop((left1, top1, right1, bottom1)).resize((224, 224))
    im2 = pair.crop((left2, top2, right2, bottom2)).resize((224, 224))
    im3 = Image.new("RGB", (224, 224), (125, 125, 125))

    return (
        ToTensor()(PILImage(im1)),
        ToTensor()(PILImage(im2)),
        ToTensor()(PILImage(im3)),
        label,
    )


def get_img_tuple_noise(path):
    pair = Image.open(path)

    label = label_func(Path(path))
    orientation = os.path.basename(path).split("_")[-3]

    width, height = pair.size

    if orientation == "normal":
        left1, top1, right1, bottom1 = width - width // 4, 0, width, height // 4
        left2, top2, right2, bottom2 = 0, height - height // 4, width // 4, height
    else:
        left1, top1, right1, bottom1 = 0, 0, width // 4, height // 4
        left2, top2, right2, bottom2 = (
            width - width // 4,
            height - height // 4,
            width,
            height,
        )

    im1 = pair.crop((left1, top1, right1, bottom1)).resize((224, 224))
    im2 = pair.crop((left2, top2, right2, bottom2)).resize((224, 224))
    im3 = Image.new("RGB", (224, 224), (125, 125, 125))
    im3 = Image.fromarray(
        np.uint8(
            skimage.util.random_noise(
                skimage.img_as_float(im3), mode="s&p", amount=1) * 255))

    return (
        ToTensor()(PILImage(im1)),
        ToTensor()(PILImage(im2)),
        ToTensor()(PILImage(im3)),
        label,
    )


class ImageTuple(fastuple):
    @classmethod
    def create(cls, fns):
        return cls(fns)

    def show(self, ctx=None, **kwargs):
        t1, t2, t3 = self
        if (not isinstance(t1, Tensor) or not isinstance(t2, Tensor)
                or t1.shape != t2.shape):
            return ctx
        line = t1.new_zeros(t1.shape[0], t1.shape[1], 10)
        return show_image(torch.cat([t1, line, t2], dim=2), ctx=ctx, **kwargs)


def ImageTupleBlock():
    return TransformBlock(type_tfms=ImageTuple.create,
                          batch_tfms=IntToFloatTensor)


def get_tuples_no_noise(files):
    return [[
        get_img_tuple_no_noise(f)[0],
        get_img_tuple_no_noise(f)[1],
        get_img_tuple_no_noise(f)[2],
        get_img_tuple_no_noise(f)[3],
    ] for f in files]


def get_tuples_noise(files):
    return [[
        get_img_tuple_noise(f)[0],
        get_img_tuple_noise(f)[1],
        get_img_tuple_noise(f)[2],
        get_img_tuple_noise(f)[3],
    ] for f in files]


def get_x(t):
    return t[:3]


def get_y(t):
    return t[3]


def make_dls(stim_path, batch_sz=24, fov_noise=False):
    stim_path = Path(stim_path)
    pairs = glob.glob(os.path.join(stim_path, "*.png"))
    fnames = sorted(Path(s) for s in pairs)
    y = [label_func(item) for item in fnames]

    splitter = TrainTestSplitter(test_size=0.2,
                                 random_state=42,
                                 shuffle=True,
                                 stratify=y)
    splits = splitter(fnames)
    # splits = RandomSplitter()(fnames)
    if fov_noise:
        siamese = DataBlock(
            blocks=(ImageTupleBlock, CategoryBlock),
            get_items=get_tuples_noise,
            get_x=get_x,
            get_y=get_y,
            splitter=splitter,
        )
    else:
        siamese = DataBlock(
            blocks=(ImageTupleBlock, CategoryBlock),
            get_items=get_tuples_no_noise,
            get_x=get_x,
            get_y=get_y,
            splitter=splitter,
        )

    dls = siamese.dataloaders(
        fnames,
        bs=batch_sz,
        seed=seed,
        # shuffle=True,
        # device = 'cuda',
    )
    # check that train and test splits have balanced classes
    train_test = ["TRAIN", "TEST"]
    for train_test_id in [0, 1]:
        s = 0
        d = 0
        for item in dls.__getitem__(train_test_id).items:
            # print(label_from_path(item))
            # print(item)
            # print('---')
            if item[3] == 1:
                s += 1
            else:
                d += 1
        print(
            f"{train_test[train_test_id]} SET (same, diff): {str(s)}, {str(d)}"
        )
    return dls


def plot_filters_multi_channel(t, path=""):

    # get the number of kernels
    num_kernels = t.shape[0]

    # define number of columns for subplots
    num_cols = 12
    # rows = num of kernels
    num_rows = num_kernels

    # set the figure size
    fig = plt.figure(figsize=(num_cols, num_rows))

    # looping through all the kernels
    for i in range(t.shape[0]):
        ax1 = fig.add_subplot(num_rows, num_cols, i + 1)

        # for each kernel, we convert the tensor to numpy
        npimg = np.array(t[i].numpy(), np.float32)

        # standardize the numpy image
        npimg = (npimg - np.mean(npimg)) / np.std(npimg)
        npimg = np.minimum(1, np.maximum(0, (npimg + 0.5)))
        npimg = npimg.transpose((1, 2, 0))
        ax1.imshow(npimg)
        ax1.axis("off")
        ax1.set_title(str(i))
        ax1.set_xticklabels([])
        ax1.set_yticklabels([])

    plt.tight_layout()
    if path != "":
        plt.savefig(path)
    plt.show()


def init_weights(net):
    for m in net.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()
    return net


def make_cf(cf_y, cf_pred, cycle, epoch, path=""):
    plt.figure(figsize=(7, 7))
    cf_matrix = confusion_matrix(cf_y, cf_pred)
    df_cm = pd.DataFrame(
        cf_matrix,
        index=[i for i in ["Same", "Different"]],
        columns=[i for i in ["Same", "Different"]],
    )
    sn.heatmap(df_cm, annot=True, cbar=False, cmap="Blues", fmt="d")
    plt.suptitle(f"Epoch {cycle+1} x {epoch+1}")
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    if path != "":
        plt.savefig(os.path.join(path, f"cm_{cycle+1}x{epoch+1}.png"))
    plt.show()


def plot_losses(tr_loss, te_loss, cycle, epoch, path=""):
    plt.plot(tr_loss, label="Train")
    plt.plot(te_loss, label="Test")
    plt.suptitle(f"Losses\nEpoch {cycle+1} x {epoch+1}")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    if path != "":
        plt.savefig(os.path.join(path, f"loss_{cycle+1}x{epoch+1}.png"))
    plt.show()


def plot_acc(tr_acc, te_acc, cycle, epoch, path=""):
    plt.plot(tr_acc, label="Train")
    plt.plot(te_acc, label="Test")
    plt.suptitle(f"Accuracy\nEpoch {cycle+1} x {epoch+1}")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.legend()
    if path != "":
        plt.savefig(os.path.join(path, f"acc_{cycle+1}x{epoch+1}.png"))
    plt.show()

In [3]:


dls = make_dls(stim_path, batch_sz, fov_noise)
train_loader = dls.train
test_loader = dls.valid

# print('\nShowing first batch...')
# dls.show_batch(max_n = 2)
# plt.show()

criterion = nn.CrossEntropyLoss()

net = SiameseNetEncoderFB().cuda()
net = init_weights(net)

url = f'https://s3.amazonaws.com/cornet-models/cornet_z-5c427c9c.pth'
ckpt_data = torch.utils.model_zoo.load_url(url)

state_dict = {
    "V1_p.0.weight": ckpt_data['state_dict']['module.V1.conv.weight'],
    "V1_p.0.bias": ckpt_data['state_dict']['module.V1.conv.bias'],
    "V2_p.0.weight": ckpt_data['state_dict']['module.V2.conv.weight'],
    "V2_p.0.bias": ckpt_data['state_dict']['module.V2.conv.bias'],
    "V4_p.0.weight": ckpt_data['state_dict']['module.V4.conv.weight'],
    "V4_p.0.bias": ckpt_data['state_dict']['module.V4.conv.bias'],
    "V1_f.0.weight": ckpt_data['state_dict']['module.V1.conv.weight'],
    "V1_f.0.bias": ckpt_data['state_dict']['module.V1.conv.bias'],
    "V2_f.0.weight": ckpt_data['state_dict']['module.V2.conv.weight'],
    "V2_f.0.bias": ckpt_data['state_dict']['module.V2.conv.bias'],
    "V4_f.0.weight": ckpt_data['state_dict']['module.V4.conv.weight'],
    "V4_f.0.bias": ckpt_data['state_dict']['module.V4.conv.bias'],
    "IT_f.0.weight": ckpt_data['state_dict']['module.IT.conv.weight'],
    "IT_f.0.bias": ckpt_data['state_dict']['module.IT.conv.bias'],
    "IT_p.0.weight": ckpt_data['state_dict']['module.IT.conv.weight'],
    "IT_p.0.bias": ckpt_data['state_dict']['module.IT.conv.bias'],
}

net.V1_p[0].weight.requires_grad = False
net.V1_p[0].bias.requires_grad = False
net.V2_p[0].weight.requires_grad = False
net.V2_p[0].bias.requires_grad = False
net.V4_p[0].weight.requires_grad = False
net.V4_p[0].bias.requires_grad = False
net.V1_f[0].weight.requires_grad = False
net.V1_f[0].bias.requires_grad = False
net.V2_f[0].weight.requires_grad = False
net.V2_f[0].bias.requires_grad = False
net.V4_f[0].weight.requires_grad = False
net.V4_f[0].bias.requires_grad = False
net.IT_f[0].weight.requires_grad = False
net.IT_f[0].bias.requires_grad = False
net.IT_p[0].weight.requires_grad = False
net.IT_p[0].bias.requires_grad = False
net.fb[0].weight.requires_grad = True
net.fb[0].bias.requires_grad = True
# net.head[0].bias.requires_grad = False

net.load_state_dict(state_dict, strict=False)
net = nn.DataParallel(net)
net.to('cuda')

params_to_update = net.parameters()
optimizer = optim.Adam(filter(lambda p: p.requires_grad, params_to_update),
                       lr=lr_min,
                       weight_decay=weight_decay)

print("\nTrain/Test started!")

for cycle in range(cycles):
    tr_loss = []
    tr_acc = []
    te_loss = []
    te_acc = []

    for epoch in range(epochs):
        # TRAIN
        net.train()

        tr_running_loss = 0.0
        tr_correct = 0
        tr_total = 0
        start = time.time()
        for (inputs, labels) in train_loader:
            optimizer.zero_grad()
            out = net(inputs)
            _, pred = torch.max(out, 1)
            loss = criterion(out, labels)
            loss.backward()
            optimizer.step()
            tr_running_loss += loss.item()
            tr_total += labels.size(0)
            tr_correct += (pred == labels).sum().item()

        tr_loss.append(tr_running_loss)
        tr_acc.append(100 * tr_correct / tr_total)

        # TEST
        net.eval()

        te_running_loss = 0.0
        te_correct = 0
        te_total = 0
        cf_pred = []
        cf_y = []
        with torch.no_grad():
            for (inputs, labels) in test_loader:
                out = net(inputs)
                _, pred = torch.max(out, 1)
                loss = criterion(out, labels)
                te_running_loss += loss.item()
                te_total += labels.size(0)
                te_correct += (pred == labels).sum().item()
                cf_y += labels.cpu().detach().tolist()
                cf_pred += pred.cpu().detach().tolist()

            te_acc.append(100 * te_correct / te_total)
            te_loss.append(te_running_loss)
            end = time.time() - start
            log_msg = f"%5d / %5d TRAIN/TEST losses: \t %.8f \t %.8f \t\t acc: \t %.2f %% \t %.2f %% \t\t time: {round(end,3)}" % (
                cycle + 1,
                epoch + 1,
                tr_running_loss,
                te_running_loss,
                100 * tr_correct / tr_total,
                100 * te_correct / te_total,
            )
            print(log_msg)

path = ''
make_cf(cf_y, cf_pred, cycle, epoch, path)
plot_losses(tr_loss, te_loss, cycle, epoch, path)
plot_acc(tr_acc, te_acc, cycle, epoch, path)

# weights = net.module.V1_f[0].weight.data.cpu()
# plot_filters_multi_channel(weights, path)
# print(weights.shape)

# weights = net.module.V1_p[0].weight.data.cpu()
# plot_filters_multi_channel(weights, path)
# print(weights.shape)

# weights = net.module.fb[0].weight.data.cpu()
# plot_filters_multi_channel(weights, path)
# print(weights.shape)

Due to IPython and Windows limitation, python multiprocessing isn't available now.
So `number_workers` is changed to 0 to avoid getting stuck
TRAIN SET (same, diff): 2688, 2688
TEST SET (same, diff): 672, 672

Train/Test started!
    1 /     1 TRAIN/TEST losses: 	 341.26323533 	 39.67501360 		 acc: 	 49.39 % 	 54.91 % 		 time: 30.691


RuntimeError: CUDA error: the launch timed out and was terminated