In [2]:
"""
Authors:
Nawras Abbas    315085043
Michael Bikman  317920999
"""
import torch
from torch import nn
from torch.distributions import Bernoulli
import torch.nn.functional as F
import torch.nn as nn
import os
import numpy as np
import torch.nn
from torch.autograd import Variable
from torch.nn import functional
import logging
import os
import sys
from datetime import datetime
import math
import random
import time
import torch
from tqdm import tqdm
import logging
import os
import pickle

# When True, messages are routed through logging.info instead of print.
LOG_ENABLED = True

# Module-wide logging callable; starts as print and is swapped below when logging is enabled.
log = print
if LOG_ENABLED:
    log = logging.info

# Timestamp taken once at import time; setup_reports() uses it to name the report folder and log file.
TIMESTAMP = datetime.now().strftime("%Y_%m_%d-%H_%M_%S")

# Data folder and file names of the train / validation / test pickles.
# NOTE(review): DATA_DIR is an absolute Google Drive path (looks like a Colab mount) — adjust when running elsewhere.
DATA_DIR = '/content/drive/MyDrive/Colab Notebooks/project'
TRAIN_DATA_FILE = r"mini-imagenet-cache-train.pkl"
VALID_DATA_FILE = r"mini-imagenet-cache-val.pkl"
TEST_DATA_FILE = r"mini-imagenet-cache-test.pkl"

# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device('cpu')
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')



# ********************************************* resnet12 *********************************************
# ********************************************* resnet12 *********************************************
# ********************************************* resnet12 *********************************************
# ********************************************* resnet12 *********************************************
# ********************************************* resnet12 *********************************************



class DropBlock(nn.Module):
    """
    DropBlock regularisation: zeroes square block_size x block_size regions of a feature map.

    The layer is only active in training mode; in eval mode the input is returned untouched.
    """

    def __init__(self, block_size):
        """
        :param block_size: side length of every dropped square region
        """
        super(DropBlock, self).__init__()
        self.block_size = block_size

    def forward(self, x, gamma):
        """
        :param x: feature map of shape (bsize, channels, height, width)
        :param gamma: Bernoulli probability that a position seeds (is the top-left corner of) a dropped block
        :return: x with the sampled blocks zeroed and the rest rescaled by (#elements / #kept elements);
                 x itself when the module is in eval mode
        """
        if not self.training:
            return x

        batch_size, channels, height, width = x.shape
        shrink = self.block_size - 1

        # Seeds are sampled on a grid that is (block_size - 1) smaller so every block fits inside the map.
        seeds = Bernoulli(gamma).sample(
            (batch_size, channels, height - shrink, width - shrink)).to(x.device)
        block_mask = self._compute_block_mask(seeds)

        total = block_mask.numel()
        # Clamp so that a fully dropped map yields zeros instead of 0 * inf = NaN.
        kept = block_mask.sum().clamp(min=1.0)

        return block_mask * x * (total / kept)

    def _compute_block_mask(self, mask):
        """
        Expand every seed of ``mask`` into a full block and return the keep-mask (1 = keep, 0 = drop).

        :param mask: 0/1 tensor of shape (bsize, channels, height - block_size + 1, width - block_size + 1)
        :return: tensor of shape (bsize, channels, height, width)
        """
        left_padding = int((self.block_size - 1) / 2)
        right_padding = int(self.block_size / 2)
        padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding))

        non_zero_idxs = mask.nonzero()  # (nr_blocks, 4) integer coordinates of the seeds
        nr_blocks = non_zero_idxs.shape[0]

        if nr_blocks > 0:
            span = torch.arange(self.block_size, device=mask.device)
            row_offsets = span.repeat_interleave(self.block_size)
            col_offsets = span.repeat(self.block_size)
            no_offset = torch.zeros_like(row_offsets)
            # (block_size ** 2, 4): batch and channel are never shifted, only the spatial coordinates
            offsets = torch.stack([no_offset, no_offset, row_offsets, col_offsets], dim=1)

            # Broadcasting pairs EVERY seed with EVERY offset. Tiling both lists with repeat() pairs
            # row i with offset i only, which leaves blocks partly undropped whenever nr_blocks and
            # block_size ** 2 share a common factor.
            block_idxs = (non_zero_idxs.unsqueeze(1) + offsets.unsqueeze(0)).reshape(-1, 4)
            padded_mask[block_idxs[:, 0], block_idxs[:, 1], block_idxs[:, 2], block_idxs[:, 3]] = 1.

        return 1 - padded_mask


# This ResNet network was designed following the practice of the following papers:
# TADAM: Task dependent adaptive metric for improved few-shot learning (Oreshkin et al., in NIPS 2018) and
# A Simple Neural Attentive Meta-Learner (Mishra et al., in ICLR 2018).

def conv3x3(in_planes, out_planes, stride=1):
    """
    Build a bias-free 3x3 2-D convolution with a padding of one pixel.

    :param in_planes: number of input channels
    :param out_planes: number of output channels
    :param stride: convolution stride
    :return: the configured nn.Conv2d module
    """
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)


class ResNer12_BasicBlock(nn.Module):
    """
    Residual block of the ResNet-12 backbone.

    Main path: three (3x3 conv -> batch norm) stages with LeakyReLU(0.1) after the first two.
    The (optionally projected) input is added, followed by LeakyReLU, a max-pool whose kernel
    size is ``stride`` and, when enabled, DropBlock.
    """

    def __init__(self, in_planes, planes, stride=1, down_sample=None, drop_rate=0.0, drop_block=False, block_size=1):
        """
        :param in_planes: channels of the input
        :param planes: channels produced by every convolution of the block
        :param stride: kernel size handed to the final nn.MaxPool2d
        :param down_sample: optional module applied to the input before the residual addition
        :param drop_rate: target drop rate; 0 disables dropping altogether
        :param drop_block: DropBlock is applied only when this is True (and drop_rate > 0)
        :param block_size: DropBlock block size
        """
        super(ResNer12_BasicBlock, self).__init__()
        self.conv1 = conv3x3(in_planes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.LeakyReLU(0.1, inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv3x3(planes, planes)
        self.bn3 = nn.BatchNorm2d(planes)
        self.max_pool = nn.MaxPool2d(stride)
        self.down_sample = down_sample
        self.stride = stride
        self.drop_rate = drop_rate
        # plain Python counter of forward calls; it drives the DropBlock keep-rate schedule
        self.num_batches_tracked = 0
        self.drop_block = drop_block
        self.block_size = block_size
        self.DropBlock = DropBlock(block_size=self.block_size)

    def forward(self, x):
        """
        :param x: input feature map
        :return: feature map after the residual block, pooling and optional DropBlock
        """
        self.num_batches_tracked += 1

        shortcut = x if self.down_sample is None else self.down_sample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        out += shortcut
        out = self.max_pool(self.relu(out))

        if self.drop_rate > 0 and self.drop_block:
            feat_size = out.size()[2]
            # the keep rate decays linearly with the number of forward calls down to 1 - drop_rate
            annealed_keep = 1.0 - self.drop_rate / (20 * 2000) * self.num_batches_tracked
            keep_rate = max(annealed_keep, 1.0 - self.drop_rate)
            gamma = (1 - keep_rate) / self.block_size ** 2 * feat_size ** 2 / (feat_size - self.block_size + 1) ** 2
            out = self.DropBlock(out, gamma=gamma)

        return out


class ResNet(nn.Module):
    """
    Four-stage residual feature extractor; every stage holds a single ``block`` with 64 filters.

    The last two stages are built with drop_block=True. The output of the final stage is flattened
    to one vector per sample.
    """

    def __init__(self, block, drop_rate, drop_block_size=5):
        """
        :param block: block class, called as
                      block(in_planes, planes, stride, down_sample, drop_rate, drop_block, block_size)
        :param drop_rate: drop rate forwarded to every block
        :param drop_block_size: block size forwarded to the two stages that use DropBlock
        """
        super(ResNet, self).__init__()
        self.in_planes = 3  # channels of the input image; updated by _make_layer after each stage

        filters = [64, 64, 64, 64]
        plain = dict(stride=2, drop_rate=drop_rate)
        with_drop_block = dict(stride=2, drop_rate=drop_rate, drop_block=True, block_size=drop_block_size)

        self.layer1 = self._make_layer(block, filters[0], **plain)
        self.layer2 = self._make_layer(block, filters[1], **plain)
        self.layer3 = self._make_layer(block, filters[2], **with_drop_block)
        self.layer4 = self._make_layer(block, filters[3], **with_drop_block)

        self.drop_rate = drop_rate
        self.flatten = nn.Sequential(
            nn.Flatten(start_dim=1)
        )

        # Kaiming-normal init for convolutions, unit scale / zero shift for batch norms
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='leaky_relu')
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def _make_layer(self, block, planes, stride=1, drop_rate=0.0, drop_block=False, block_size=1):
        """
        Build one stage (a single block) and record ``planes`` as the input width of the next stage.

        A 1x1 conv + batch-norm projection is created for the shortcut whenever the stride is not 1
        or the channel count changes.
        """
        needs_projection = stride != 1 or self.in_planes != planes
        down_sample = None
        if needs_projection:
            down_sample = nn.Sequential(
                nn.Conv2d(self.in_planes, planes,
                          kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(planes),
            )
        stage = nn.Sequential(
            block(self.in_planes, planes, stride, down_sample, drop_rate, drop_block, block_size)
        )
        self.in_planes = planes
        return stage

    def forward(self, x):
        """
        :param x: batch of images
        :return: flattened features of the last stage, one row per sample
        """
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        return self.flatten(x)


def resnet12(drop_rate):
    """
    Factory for the ResNet-12 backbone.

    :param drop_rate: drop rate forwarded to the ResNet constructor
    :return: a ResNet assembled from ResNer12_BasicBlock blocks
    """
    backbone = ResNet(ResNer12_BasicBlock, drop_rate)
    return backbone





# ********************************************* resnet18 *********************************************
# ********************************************* resnet18 *********************************************
# ********************************************* resnet18 *********************************************
# ********************************************* resnet18 *********************************************
# ********************************************* resnet18 *********************************************




class ResNer18_BasicBlock(nn.Module):
    """
    ResNet-18 residual block with an identity shortcut.

    Two (3x3 conv -> batch norm) stages; the unmodified input is added before the final ReLU,
    so the block requires in_channels == out_channels. Use BasicBlockDownSample when the channel
    count or the resolution changes.
    """

    def __init__(self, in_channels, out_channels):
        super(ResNer18_BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.Sequential(
            nn.ReLU(inplace=True),
        )
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """
        :param x: tensor of shape (batch, in_channels, height, width)
        :return: tensor of the same shape as x
        """
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += residual
        out = self.relu(out)
        return out


class BasicBlockDownSample(nn.Module):
    """
    ResNet-18 residual block that changes the channel count and halves the resolution.

    The first convolution uses stride 2; the shortcut is a stride-2 1x1 convolution followed by
    batch norm so that both branches have matching shapes.
    """

    def __init__(self, in_channels, out_channels):
        super(BasicBlockDownSample, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=2, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.Sequential(
            nn.ReLU(inplace=True),
        )
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.down_sample = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2, bias=False),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        """
        :param x: tensor of shape (batch, in_channels, height, width)
        :return: tensor with out_channels channels and the spatial size reduced by the stride-2 convolutions
        """
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        residual = self.down_sample(x)
        out += residual
        out = self.relu(out)
        return out


class ResNet18(nn.Module):
    """
    ResNet-18 style feature extractor without a classification head: the final feature map is flattened.

    Stages 2-4 start with a BasicBlockDownSample. An identity-shortcut block cannot change the
    channel count (64 -> 128 -> 256 -> 512), and the per-stage shapes annotated in forward()
    require the resolution to be halved at every stage.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.Sequential(
            nn.ReLU(inplace=True),
        )
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = nn.Sequential(
            ResNer18_BasicBlock(64, 64),
            ResNer18_BasicBlock(64, 64)
        )
        self.layer2 = nn.Sequential(
            BasicBlockDownSample(64, 128),
            ResNer18_BasicBlock(128, 128)
        )
        self.layer3 = nn.Sequential(
            BasicBlockDownSample(128, 256),
            ResNer18_BasicBlock(256, 256)
        )
        self.layer4 = nn.Sequential(
            BasicBlockDownSample(256, 512),
            ResNer18_BasicBlock(512, 512)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(start_dim=1)
        )

    def forward(self, x):  # example shapes below are for an input of (200,3,84,84)
        """
        :param x: batch of images, shape (batch, 3, height, width)
        :return: flattened features, shape (batch, 512 * final_height * final_width)
        """
        x = self.conv1(x)  # (200,64,42,42)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.max_pool(x)  # (200,64,21,21)

        x = self.layer1(x)  # out (200,64,21,21)
        x = self.layer2(x)  # out (200,128,11,11)
        x = self.layer3(x)  # out (200,256,6,6)
        x = self.layer4(x)  # out (200,512,3,3)
        x = self.classifier(x)  # out (200, 4608)

        return x


# ********************************************* models *********************************************
# ********************************************* models *********************************************
# ********************************************* models *********************************************
# ********************************************* models *********************************************
# ********************************************* models *********************************************




def euclidean_dist(x, y):
    """
    Squared Euclidean distance between every query embedding and every prototype.

    :param x: size [n_query_total, out_dim] - queries
    :param y: size [n_ways, out_dim] - prototypes
    :return: tensor of size [n_query_total, n_ways]; entry (i, j) is the sum of squared differences
             between x[i] and y[j]
    :raises ValueError: when queries and prototypes have different embedding sizes
    """
    query_dim = x.size(1)
    proto_dim = y.size(1)
    if proto_dim != query_dim:
        raise ValueError(f'Pic embedding for prototype {y.size(1)} and query {query_dim} data arent equal')

    # broadcast [n_query_total, 1, d] against [1, n_ways, d]
    pairwise_diff = x[:, None, :] - y[None, :, :]
    return torch.pow(pairwise_diff, 2).sum(2)


def mahalanobis_dist(x, y, n_queries=15):
    """
    Mahalanobis distance between every query embedding and every class prototype.

    For each class the covariance is estimated from that class's own queries and reduced to its
    diagonal before being handed to mahalanobis().

    :param x: size [n_query_total, out_dim] - queries, grouped by class in blocks of n_queries rows
    :param y: size [n_ways, out_dim] - prototypes
    :param n_queries: number of queries per class (defaults to 15, the previously hard-coded value)
    :return: tensor of size [n_query_total, n_ways] on the device of x
    """
    n_query_total = x.size(0)
    n_ways = y.size(0)  # number of classes = n_ways
    res = torch.zeros(n_query_total, n_ways).to(x.device)  # size = [n_query_total, n_ways]
    queries_per_class = x.split(n_queries, dim=0)  # tuple of [n_queries, out_dim] chunks
    prototypes_per_class = y.split(1, dim=0)  # tuple of [1, out_dim] chunks
    batches = int(n_query_total / n_queries)

    for class_ndx in range(n_ways):
        proto = prototypes_per_class[class_ndx]

        # The covariance depends on the class only, so it is computed once per class
        # instead of once per (class, query batch) pair.
        class_queries = queries_per_class[class_ndx].detach().cpu()
        cov_arr = np.cov(class_queries.numpy().T)
        cov = torch.from_numpy(cov_arr).to(x.device)
        cov = torch.diag(torch.diag(cov))  # keep the per-dimension variances only

        for query_batch_ndx in range(batches):
            query_batch = queries_per_class[query_batch_ndx]
            for query_ndx in range(n_queries):
                query = query_batch[query_ndx, :]
                dist = mahalanobis(proto, query, cov)
                q = n_queries * query_batch_ndx + query_ndx
                res[q, class_ndx] = dist.item()
    return res


def mahalanobis(u, v, cov):
    """
    Mahalanobis distance sqrt((u - v) * cov^-1 * (u - v)^T), evaluated in double precision.

    :param u: first point; (u - v) must be 2-D because it is transposed over dims 0 and 1
    :param v: second point, broadcastable against u
    :param cov: square covariance matrix; it is inverted before being cast to double
    :return: tensor holding the distance
    """
    diff = (u - v).double()
    precision = torch.inverse(cov).double()
    diff_t = torch.transpose(diff, 0, 1).double()
    quadratic_form = torch.matmul(torch.matmul(diff, precision), diff_t)
    return torch.sqrt(quadratic_form)


class ProtoNetSimple(nn.Module):
    """
    Prototypical network with a four-block convolutional encoder.

    Every block is conv(3x3) -> batch norm -> ReLU -> 2x2 max-pool; the final feature map is
    flattened into the embedding. Class prototypes are the mean support embeddings, and queries
    are classified with a softmax over the negative, alpha-scaled distances to the prototypes.
    """

    def __init__(self, num_filters=128):
        """
        :param num_filters: number of output channels of every convolutional block
        """
        super().__init__()

        self.file_name = 'ProtoNetSimple.pt'  # file name used by save()

        self.block1 = self._cnn_block(3, num_filters)
        self.block2 = self._cnn_block(num_filters, num_filters)
        self.block3 = self._cnn_block(num_filters, num_filters)
        self.block4 = self._cnn_block(num_filters, num_filters)

        self.classifier = nn.Sequential(
            nn.Flatten(start_dim=1)
        )

        # learnable factor that scales the distances before the softmax
        self.alpha = nn.Parameter(torch.tensor(1.0, requires_grad=True))

    @staticmethod
    def _cnn_block(in_channels, out_channels):
        """
        :return: conv(3x3, padding 1) -> batch norm -> ReLU -> max-pool(2) as one nn.Sequential
        """
        block = torch.nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2)
        )
        return block

    def forward(self, x):
        """
        :param x: batch of images, size [batch, channels=3, height, width]
        :return: embeddings, one flattened vector per image
        """
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        vector = self.classifier(x)
        return vector

    def loss(self, episode, image_data):
        """
        Compute the prototypical loss and the accuracy of one episode.

        :param episode: object providing get_support_sample(image_data) and get_query_sample(image_data)
        :param image_data: image tensor the episode indexes into
        :return: tuple (loss tensor, {'loss': float, 'acc': float})
        :raises ValueError: when support and query data disagree on the number of classes or the picture size
        """
        xs = episode.get_support_sample(image_data)
        # xs size = [n_ways, n_shots, channels=3, width=84, height=84]
        xq = episode.get_query_sample(image_data)
        # xq size = [n_ways, n_query_points, channels=3, width=84, height=84]

        n_class = xs.size(0)
        if xq.size(0) != n_class:
            raise ValueError(f'Number of classes for support {xs.size(0)} and query {xq.size(0)} data is not equal')
        n_support = xs.size(1)
        n_query = xq.size(1)

        # ground truth: every query of class k carries the label k; size = [n_class, n_query, 1]
        target_indices = torch.arange(0, n_class).view(n_class, 1, 1).expand(n_class, n_query, 1).long().to(DEVICE)

        support_pic_size = xs.size()[2:]
        query_pic_size = xq.size()[2:]
        if query_pic_size != support_pic_size:
            raise ValueError(f'Pic sizes for support {support_pic_size} and query {query_pic_size} data arent equal')
        n_support_total = n_class * n_support
        n_query_total = n_class * n_query
        xs_view = xs.view(n_support_total, *support_pic_size)
        xq_view = xq.view(n_query_total, *query_pic_size)
        x = torch.cat([xs_view, xq_view], 0).float().to(DEVICE)  # input for the model
        # x has dimension of [n_support_total + n_query_total, channels, width, height]

        z = self.forward(x)  # output with dimension [n_support_total + n_query_total, z_dim]
        z_dim = z.size(-1)

        z_support = z[:n_support_total]  # size = [n_support_total, z_dim]
        # prototype = average of all support embeddings of a class
        prototypes_per_class = z_support.view(n_class, n_support, z_dim).mean(1)  # size = [n_class, z_dim]
        query_vectors = z[n_support_total:]  # size = [n_query_total, z_dim]

        dists_per_class = euclidean_dist(query_vectors, prototypes_per_class)  # size = [n_query_total, n_ways]
        # dists_per_class = mahalanobis_dist(query_vectors, prototypes_per_class)  # size = [n_query_total, n_ways]

        # alpha parameter is used to scale the distance metric to obtain better results
        if self.alpha is not None:
            dists_per_class = torch.mul(self.alpha, dists_per_class)

        log_p_y = torch.nn.functional.log_softmax(-dists_per_class, dim=1).view(n_class, n_query, -1)  # log(p(y=k|x))
        # log_p_y = size [n_class = n_ways, n_query, n_class]
        loss_per_query = -log_p_y.gather(2, target_indices).view(-1)  # size = [n_query_total]
        loss_val = loss_per_query.mean()  # average loss for all queries
        _, y_hat = log_p_y.max(dim=2)  # returns tuple (max values, argmax indices)
        # y_hat size = [n_class, n_query]

        # Accuracy = share of matches between y_hat and the ground truth. Only the trailing
        # singleton axis is squeezed: a bare squeeze() would also remove the query axis when
        # n_query == 1 and make torch.eq broadcast to [n_class, n_class].
        acc_val = torch.eq(y_hat, target_indices.squeeze(2)).float().mean()

        return loss_val, {
            'loss': loss_val.item(),
            'acc': acc_val.item()
        }

    def save(self, report_folder):
        """
        Write the model's state dict to <report_folder>/<self.file_name>.
        """
        model_file = os.path.join(report_folder, self.file_name)
        torch.save(self.state_dict(), model_file)


class ProtoNetComplex(ProtoNetSimple):
    """
    ProtoNetSimple encoder followed by a two-layer fully connected head producing 128-d embeddings.
    """

    def __init__(self, num_filters=64):
        """
        :param num_filters: number of output channels of every convolutional block
        """
        # The parent constructor already builds block1..block4 with num_filters channels,
        # so they are not created a second time here.
        super().__init__(num_filters)

        self.file_name = 'ProtoNetComplex.pt'

        # Flattened encoder output for 84x84 inputs: four 2x2 max-pools give 84 -> 42 -> 21 -> 10 -> 5,
        # i.e. num_filters * 5 * 5 features (1600 for the default of 64 filters).
        flat_features = num_filters * 5 * 5

        self.classifier = nn.Sequential(
            nn.Flatten(start_dim=1),
            nn.Linear(flat_features, flat_features),
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.Linear(flat_features, 128)
        )


class ProtoNetRes18(ProtoNetSimple):
    """
    Prototypical network whose embeddings come from a ResNet18 backbone.

    The loss and save logic are inherited from ProtoNetSimple; only forward() is replaced.
    """

    def __init__(self):
        super().__init__()
        self.file_name = 'ProtoNetRes18.pt'
        self.res_net = ResNet18()

    def forward(self, x):
        """
        :param x: batch of images
        :return: flattened ResNet18 features
        """
        return self.res_net.forward(x)


class ProtoNetRes12(ProtoNetSimple):
    """
    Prototypical network whose embeddings come from the ResNet-12 backbone built by resnet12().

    The loss and save logic are inherited from ProtoNetSimple; only forward() is replaced.
    """

    def __init__(self, drop_rate=0.1):
        """
        :param drop_rate: drop rate forwarded to resnet12()
        """
        super().__init__()
        self.file_name = 'ProtoNetRes12.pt'
        self.res_net = resnet12(drop_rate=drop_rate)

    def forward(self, x):
        """
        :param x: batch of images
        :return: flattened ResNet-12 features
        """
        return self.res_net.forward(x)


class ResNetSimple(ProtoNetSimple):
    """
    Small residual CNN encoder made of four stages.

    Every stage adds a (conv -> ReLU -> conv -> ReLU -> conv) body to a (conv -> batch norm)
    shortcut computed from the same input, then applies ReLU and a 2x2 max-pool.
    NOTE(review): stages 2, 3 and 4 all run through block2 / blockRes2, i.e. they share weights —
    confirm that this is intended.
    """

    def __init__(self, num_filters=64):
        """
        :param num_filters: number of channels produced by every stage
        """
        super().__init__()
        self.file_name = 'ResNetSimple.pt'
        self.maxPool = nn.MaxPool2d(kernel_size=2)
        self.alpha = nn.Parameter(torch.tensor(1.0, requires_grad=True))
        self.relu = nn.ReLU(inplace=True)
        self.flatten = nn.Sequential(
            nn.Flatten(start_dim=1)
        )
        self.BN = nn.BatchNorm2d(num_filters)  # not referenced by forward()

        # stage 1 maps the 3 image channels to num_filters
        self.block1 = self._cnn_block(3, num_filters)
        self.blockRes1 = self._cnn_block_Res(3, num_filters)

        # body / shortcut pair used by the remaining stages
        self.block2 = self._cnn_block(num_filters, num_filters)
        self.blockRes2 = self._cnn_block_Res(num_filters, num_filters)

    @staticmethod
    def _cnn_block(in_channels, out_channels):
        """
        :return: stage body: conv -> ReLU -> conv -> ReLU -> conv (all 3x3, padding 1)
        """
        return torch.nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        )

    @staticmethod
    def _cnn_block_Res(in_channels, out_channels):
        """
        :return: stage shortcut: conv (3x3, padding 1) -> batch norm
        """
        return torch.nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        """
        :param x: batch of images
        :return: flattened features after the four stages
        """
        stages = (
            (self.block1, self.blockRes1),
            (self.block2, self.blockRes2),
            (self.block2, self.blockRes2),
            (self.block2, self.blockRes2),
        )
        for body, shortcut in stages:
            # the shortcut is evaluated on the stage input, then summed with the body output
            skip = shortcut(x)
            x = body(x) + skip
            x = self.maxPool(self.relu(x))

        return self.flatten(x)






# ********************************************* parameters *********************************************
# ********************************************* parameters *********************************************
# ********************************************* parameters *********************************************
# ********************************************* parameters *********************************************
# ********************************************* parameters *********************************************




class Episode(object):
    """
    One few-shot episode: for every participating class, the indices of its support and query pictures.
    """

    def __init__(self):
        # positions of the support / query index lists inside each tuple stored in self.indexes
        self.SUPPORT_DATA_NDX = 0
        self.QUERY_DATA_NDX = 1
        self.indexes = {}  # class_name -> (support pics indices, query pics indices)

    def add_indices(self, class_name, support_ndxs, query_ndxs):
        """
        Register the support and query picture indices of one class.
        """
        self.indexes[class_name] = (support_ndxs, query_ndxs)

    def get_support_sample(self, image_data):
        """
        :return: support pictures of all classes, stacked along a new leading class axis
        """
        return self._get_data_sample(image_data, self.SUPPORT_DATA_NDX)

    def get_query_sample(self, image_data):
        """
        :return: query pictures of all classes, stacked along a new leading class axis
        """
        return self._get_data_sample(image_data, self.QUERY_DATA_NDX)

    def _get_data_sample(self, image_data, data_index):
        """
        Gather the pictures of every class and move the last picture axis in front of the two before it.

        :param image_data: 4-D tensor indexed by picture along dim 0
        :param data_index: SUPPORT_DATA_NDX or QUERY_DATA_NDX
        :return: tensor of size [n_classes, n_pictures, d3, d1, d2] for image_data of size [N, d1, d2, d3]
        """
        per_class = [
            torch.index_select(image_data, 0, torch.tensor(ndx_pair[data_index])).permute(0, 3, 1, 2)
            for ndx_pair in self.indexes.values()
        ]
        return torch.stack(per_class, dim=0)


class RunParameters(object):
    """
    Settings container for one train + validation run.
    """

    def __init__(self):
        # model and optimisation objects
        self.model = None
        self.loss = None
        self.optimizer = None
        self.scheduler = None
        # data
        self.train_data = None
        self.val_data = None
        self.report_path = None  # path where to save PT file and logs
        # schedule
        self.epochs = 0
        self.episodes_per_epoch = 0
        self.patience = 0
        # episode layout
        self.n_ways = 0
        self.n_support_examples = 0
        self.n_query_examples = 0

    def __str__(self):
        """
        Multi-line, human readable summary of the run settings.
        """
        parts = (
            '--------Run params--------\n',
            f'Epochs: {self.epochs}\n',
            f'Loss: {self.loss}\n',
            f'Optimizer:{self.optimizer}\n',
            f'Scheduler:{self.scheduler}\n',
            f'report_path:{self.report_path}\n',
            f'episodes_per_epoch:{self.episodes_per_epoch}\n',
            f'patience:{self.patience}\n',
            f'n_ways:{self.n_ways}\n',
            f'n_support_examples:{self.n_support_examples}\n',
            f'n_query_examples:{self.n_query_examples}\n',
            '-----------------------------\n',
        )
        return ''.join(parts)


class TrainResult(object):
    """
    Per-epoch metrics collected during a train + validation run.
    """

    def __init__(self):
        self.best_epoch = 0
        # one entry per epoch
        self.train_loss_per_epoch = []
        self.train_accuracy_per_epoch = []
        self.validation_loss_per_epoch = []
        self.validation_accuracy_per_epoch = []

    def train_loss_min(self):
        """Lowest train loss over all recorded epochs."""
        return min(self.train_loss_per_epoch)

    def valid_loss_min(self):
        """Lowest validation loss over all recorded epochs."""
        return min(self.validation_loss_per_epoch)

    def train_accuracy(self):
        """Highest train accuracy over all recorded epochs."""
        return max(self.train_accuracy_per_epoch)

    def validation_accuracy(self):
        """Highest validation accuracy over all recorded epochs."""
        return max(self.validation_accuracy_per_epoch)

    def __str__(self):
        """
        Multi-line summary with the best value of every metric and the best epoch.
        """
        parts = (
            '--------Run result--------\n',
            f'Train Loss: {self.train_loss_min()}\n',
            f'Train Acc: {self.train_accuracy()}\n',
            f'Valid Loss:{self.valid_loss_min()}\n',
            f'Valid Acc:{self.validation_accuracy()}\n',
            f'Best Epoch:{self.best_epoch}\n',
            '-----------------------------\n',
        )
        return ''.join(parts)


class TestResult:
    """
    Accuracy and loss measured by a test evaluation.
    """

    def __init__(self):
        self.acc, self.loss = 0, 0  # accuracy, loss

    def __str__(self):
        """
        Multi-line summary of the test accuracy and loss.
        """
        parts = (
            '--------Test result--------\n',
            f'Test Acc: {self.acc}\n',
            f'Test Loss:{self.loss}\n',
            '-----------------------------\n',
        )
        return ''.join(parts)




# ********************************************* utils *********************************************
# ********************************************* utils *********************************************
# ********************************************* utils *********************************************
# ********************************************* utils *********************************************
# ********************************************* utils *********************************************



def setup_reports():
    """
    Create the report folder and configure the root logger to write both to a file and to the console.

    The folder is reports/<TIMESTAMP>_report and the log file inside it is <TIMESTAMP>.log;
    the model is meant to be saved in the same folder as the logs.

    :return: path of the report folder
    """
    report_root = 'reports'
    report_folder = f'{TIMESTAMP}_report'
    report_path = os.path.join(report_root, report_folder)
    # creates the root folder as well; no error when the folders already exist
    os.makedirs(report_path, exist_ok=True)

    logging.getLogger('matplotlib.font_manager').disabled = True
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)

    # os.path.join instead of a literal backslash: on POSIX a backslash is part of the file name,
    # which would put the log next to the report folder instead of inside it.
    log_file = os.path.join(report_path, f'{TIMESTAMP}.log')
    output_file_handler = logging.FileHandler(log_file)
    stdout_handler = logging.StreamHandler(sys.stdout)
    logger.addHandler(output_file_handler)
    logger.addHandler(stdout_handler)
    return report_path





# ********************************************* engine *********************************************
# ********************************************* engine *********************************************
# ********************************************* engine *********************************************
# ********************************************* engine *********************************************
# ********************************************* engine *********************************************


def choose_episode_classes(class_dict, n_ways):
    """
    Choose which classes will participate in the episode (uniformly at random, without replacement)

    :param class_dict: dictionary of all classes
    :param n_ways: number of classes per episode
    :return: list of n_ways distinct keys of class_dict
    :raises ValueError: when n_ways exceeds the number of classes (raised by random.sample)
    """
    # random.sample needs a sequence: a dict keys view is rejected with TypeError on Python >= 3.11
    # (and was deprecated before that), so the keys are materialised first.
    all_classes = list(class_dict)
    return random.sample(all_classes, n_ways)


def create_episodes(class_dict, n_episodes, n_ways, n_supports, n_queries):
    """
    Sample random episodes.

    :param class_dict: dictionary class_name -> picture indices of that class
    :param n_episodes: number of episodes to attempt
    :param n_ways: number of classes per episode
    :param n_supports: support pictures per class
    :param n_queries: query pictures per class
    :return: list of Episode objects; an attempt whose class selection raises ValueError is
             skipped, so the list may hold fewer than n_episodes entries
    """
    pictures_per_class = n_supports + n_queries
    episodes = []

    for _ in range(n_episodes):
        try:
            chosen_classes = choose_episode_classes(class_dict, n_ways)
        except ValueError:
            # NOTE(review): the failed attempt is dropped silently
            continue

        episode = Episode()
        for class_name in chosen_classes:
            # draw distinct pictures of the class; the first n_supports become the support set
            picked = random.sample(class_dict[class_name], pictures_per_class)
            episode.add_indices(class_name, picked[:n_supports], picked[n_supports:])

        episodes.append(episode)

    return episodes


def train(parameters, log):
    """
    Run training procedure with validation-loss based early stopping

    After every epoch the model is validated; whenever the validation loss improves, the model is
    saved and the (1-based) epoch number is recorded as ``best_epoch``. Training stops early once
    more than ``parameters.patience`` consecutive epochs bring no improvement.

    :param log: logger to file / console
    :param parameters: training parameters context
    :return: train result
    """
    log('Train started')
    train_res = TrainResult()

    best_loss = math.inf
    epochs_since_best = 0
    for ep in range(parameters.epochs):
        epoch_start_time = time.time()

        train_loop(ep, parameters, train_res, log)
        validation_loop(parameters, train_res, log, "validation")
        if parameters.scheduler is not None:
            parameters.scheduler.step()

        # early stopping
        epoch_loss = train_res.validation_loss_per_epoch[-1].item()
        if epoch_loss < best_loss:
            epochs_since_best = 0
            best_loss = epoch_loss
            # Recorded on every improvement so that it is also correct when all epochs run
            # without early stopping; 1-based, which equals the former `ep - patience` value
            # at the moment early stopping triggers.
            train_res.best_epoch = ep + 1
            parameters.model.save(parameters.report_path)
            log('best model saved')
        else:
            epochs_since_best = epochs_since_best + 1
        if epochs_since_best > parameters.patience:
            # patience exhausted - get out
            return train_res

        elapsed_time = time.time() - epoch_start_time
        log(f'time: {elapsed_time:5.2f} sec')

    return train_res


def train_loop(ep, parameters, train_res, log):
    """
    Train the model for one epoch on freshly sampled episodes and record the epoch averages.

    :param ep: zero-based epoch index (reported as ep + 1)
    :param parameters: run parameters (model, optimizer, train data and episode layout)
    :param train_res: result object; the epoch's mean loss and accuracy are appended to it
    :param log: logging callable
    :return: None
    """
    model, optimizer = parameters.model, parameters.optimizer
    images = torch.from_numpy(parameters.train_data['image_data'])
    class_dict = parameters.train_data['class_dict']

    model.train()
    episodes = create_episodes(
        class_dict,
        parameters.episodes_per_epoch,
        parameters.n_ways,
        parameters.n_support_examples,
        parameters.n_query_examples,
    )
    log(f'Epoch {ep + 1}')

    episode_losses, episode_accuracies = [], []
    for episode in tqdm(episodes, desc=f"Epoch {ep + 1}/{parameters.epochs} train"):
        # one optimisation step per episode
        optimizer.zero_grad()
        loss, stats = model.loss(episode, images)
        loss.backward()
        optimizer.step()
        episode_losses.append(stats['loss'])
        episode_accuracies.append(stats['acc'])

    epoch_train_loss = torch.tensor(episode_losses).mean()
    epoch_train_acc = torch.tensor(episode_accuracies).mean()
    log(f'train loss: {epoch_train_loss:.6f}')
    log(f'train acc: {epoch_train_acc:.6f}')

    # update train losses in result here
    train_res.train_accuracy_per_epoch.append(epoch_train_acc)
    train_res.train_loss_per_epoch.append(epoch_train_loss)


def validation_loop(parameters, train_res, log, desc):
    """
    Evaluate the model on freshly created episodes with gradient tracking disabled
    :param desc: label used for the progress bar and as prefix of the logged loss/acc lines
    :param parameters: run settings; this function reads model, n_ways, episodes_per_epoch,
        n_support_examples, n_query_examples and val_data
    :param train_res: result object; one entry is appended to each of its
        validation_accuracy_per_epoch and validation_loss_per_epoch lists
    :param log: callable used to report the mean loss and accuracy
    :return: None
    """
    net = parameters.model
    # order expected by create_episodes: episodes, ways, supports, queries
    episode_spec = (parameters.episodes_per_epoch, parameters.n_ways,
                    parameters.n_support_examples, parameters.n_query_examples)
    images = torch.from_numpy(parameters.val_data['image_data'])
    class_dict = parameters.val_data['class_dict']

    net.eval()
    # one (loss, acc) pair per episode, as reported by model.loss
    episode_stats = []
    with torch.no_grad():
        episodes = create_episodes(class_dict, *episode_spec)
        for episode in tqdm(episodes, desc=desc):
            _, output = net.loss(episode, images)
            episode_stats.append((output['loss'], output['acc']))

    mean_loss = torch.mean(torch.tensor([stat[0] for stat in episode_stats]))
    mean_acc = torch.mean(torch.tensor([stat[1] for stat in episode_stats]))
    log(f'{desc} loss: {mean_loss:.6f}')
    log(f'{desc} acc: {mean_acc:.6f}')

    # the means are stored as 0-dim tensors; test() calls .item() on them
    train_res.validation_accuracy_per_epoch.append(mean_acc)
    train_res.validation_loss_per_epoch.append(mean_loss)


def test(parameters, log):
    """
    Run a single evaluation pass through validation_loop and package its outcome
    :param parameters: run settings forwarded to validation_loop
        (validation_loop reads parameters.val_data, so the split to evaluate
        presumably has to be placed there by the caller - confirm)
    :param log: callable used to report progress messages
    :return: TestResult with acc and loss set to plain Python numbers
    """
    log('Test started')
    # validation_loop appends to a TrainResult, so a throwaway one collects the single pass
    scratch = TrainResult()
    started_at = time.time()
    validation_loop(parameters, scratch, log, "Test")
    log(f'time: {time.time() - started_at:5.2f} sec')

    outcome = TestResult()
    outcome.acc = scratch.validation_accuracy_per_epoch[-1].item()
    outcome.loss = scratch.validation_loss_per_epoch[-1].item()
    return outcome





# ********************************************* run_train *********************************************
# ********************************************* run_train *********************************************
# ********************************************* run_train *********************************************
# ********************************************* run_train *********************************************
# ********************************************* run_train *********************************************



def load_train_data():
    """
    Read the train split from the pickle file TRAIN_DATA_FILE inside DATA_DIR
    :return: the unpickled object (main() later indexes it with 'image_data' and 'class_dict')
    """
    train_path = os.path.join(DATA_DIR, TRAIN_DATA_FILE)
    # pickle can execute code while loading - only point this at trusted files
    with open(train_path, "rb") as pickle_file:
        train_split = pickle.load(pickle_file)
    log(f'Loaded train data: {TRAIN_DATA_FILE}')
    return train_split


def load_validation_data():
    """
    Read the validation split from the pickle file VALID_DATA_FILE inside DATA_DIR
    :return: the unpickled object (main() later indexes it with 'image_data' and 'class_dict')
    """
    valid_path = os.path.join(DATA_DIR, VALID_DATA_FILE)
    # pickle can execute code while loading - only point this at trusted files
    with open(valid_path, "rb") as pickle_file:
        valid_split = pickle.load(pickle_file)
    log(f'Loaded validation data: {VALID_DATA_FILE}')
    return valid_split


def main():
    """
    Configure and launch one ProtoNetRes12 training run
    Sets up the report location, builds the model with an Adam optimizer and a StepLR
    scheduler, loads the train and validation splits, fills a RunParameters object and
    hands it to train(); the returned result is logged
    :return: None
    """
    report_path = setup_reports()
    log(DEVICE)

    # NOTE(review): these two values are only logged; they are not passed to
    # ProtoNetRes12() below - confirm they match the model's actual configuration
    num_filters = 64
    log(f'Num Filters: {num_filters}')
    drop_rate = 0.2
    log(f'Drop block rate: {drop_rate}')

    model = ProtoNetRes12().to(DEVICE)
    log(model)

    # the first registered parameter is reported as "Alpha"; .item() requires it to hold one element
    model_params = list(model.parameters())
    log(f'Alpha:{model_params[0].item()}')

    learning_rate = 0.002
    log(f'LR:{learning_rate}')
    lr_step = 50
    log(f'LR Step:{lr_step}')
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # the learning rate is halved every lr_step scheduler steps
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_step, gamma=0.5, last_epoch=-1)

    # load train & validation data
    train_split = load_train_data()
    valid_split = load_validation_data()

    train_params = RunParameters()
    run_settings = {
        'epochs': 10000,  # max number of epochs
        'optimizer': optimizer,
        'scheduler': scheduler,
        'train_data': train_split,
        'val_data': valid_split,
        'n_ways': 5,  # ways
        'patience': 200,
        'episodes_per_epoch': 100,
        'n_support_examples': 5,  # shots
        'n_query_examples': 15,
        'model': model,
        'report_path': report_path,
    }
    for field_name, field_value in run_settings.items():
        setattr(train_params, field_name, field_value)
    log(train_params)

    outcome = train(train_params, log)
    log(outcome)

In [None]:
# Entry point of the notebook: run the whole training pipeline, then print 'OK' once main() returns.
if __name__ == '__main__':
    main()
    print('OK')

cuda
Num Filters: 64
Drop block rate: 0.2
ProtoNetRes12(
  (block1): Sequential(
    (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (block2): Sequential(
    (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (block3): Sequential(
    (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (block4): Sequential(


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
Epoch 1/10000 train: 100%|██████████| 100/100 [00:19<00:00,  5.03it/s]

train loss: 41.196465
train acc: 0.431733



validation: 100%|██████████| 100/100 [00:06<00:00, 14.99it/s]

validation loss: 7.933660
validation acc: 0.400667
best model saved
time: 26.66 sec
Epoch 2



Epoch 2/10000 train: 100%|██████████| 100/100 [00:20<00:00,  4.97it/s]

train loss: 3.996119
train acc: 0.442267



validation: 100%|██████████| 100/100 [00:06<00:00, 14.67it/s]

validation loss: 2.310611
validation acc: 0.408800
best model saved
time: 27.02 sec
Epoch 3



Epoch 3/10000 train: 100%|██████████| 100/100 [00:20<00:00,  4.87it/s]

train loss: 1.792603
train acc: 0.441733



validation: 100%|██████████| 100/100 [00:07<00:00, 14.14it/s]

validation loss: 1.526839
validation acc: 0.416800
best model saved
time: 27.71 sec
Epoch 4



Epoch 4/10000 train: 100%|██████████| 100/100 [00:21<00:00,  4.76it/s]

train loss: 1.448980
train acc: 0.461333



validation: 100%|██████████| 100/100 [00:06<00:00, 14.52it/s]

validation loss: 1.450800
validation acc: 0.413867
best model saved
time: 28.06 sec
Epoch 5



Epoch 5/10000 train: 100%|██████████| 100/100 [00:21<00:00,  4.63it/s]

train loss: 1.303084
train acc: 0.479067



validation: 100%|██████████| 100/100 [00:06<00:00, 14.41it/s]

validation loss: 1.386311
validation acc: 0.437200
best model saved
time: 28.61 sec
Epoch 6



Epoch 6/10000 train: 100%|██████████| 100/100 [00:21<00:00,  4.70it/s]

train loss: 1.296773
train acc: 0.489333



validation: 100%|██████████| 100/100 [00:06<00:00, 14.57it/s]

validation loss: 1.350626
validation acc: 0.446133
best model saved
time: 28.21 sec
Epoch 7



Epoch 7/10000 train: 100%|██████████| 100/100 [00:21<00:00,  4.72it/s]

train loss: 1.263672
train acc: 0.493600



validation: 100%|██████████| 100/100 [00:06<00:00, 14.50it/s]

validation loss: 1.341240
validation acc: 0.450533
best model saved
time: 28.15 sec
Epoch 8



Epoch 8/10000 train: 100%|██████████| 100/100 [00:21<00:00,  4.67it/s]

train loss: 1.282258
train acc: 0.479467



validation: 100%|██████████| 100/100 [00:06<00:00, 14.51it/s]

validation loss: 1.355444
validation acc: 0.439600
time: 28.35 sec
Epoch 9



Epoch 9/10000 train: 100%|██████████| 100/100 [00:21<00:00,  4.70it/s]

train loss: 1.248510
train acc: 0.492667



validation: 100%|██████████| 100/100 [00:06<00:00, 14.66it/s]

validation loss: 1.335311
validation acc: 0.446533
best model saved
time: 28.25 sec
Epoch 10



Epoch 10/10000 train: 100%|██████████| 100/100 [00:21<00:00,  4.71it/s]

train loss: 1.237673
train acc: 0.502400



validation: 100%|██████████| 100/100 [00:06<00:00, 14.56it/s]

validation loss: 1.306427
validation acc: 0.464800
best model saved
time: 28.20 sec
Epoch 11



Epoch 11/10000 train: 100%|██████████| 100/100 [00:21<00:00,  4.70it/s]

train loss: 1.215767
train acc: 0.516133



validation: 100%|██████████| 100/100 [00:06<00:00, 14.51it/s]

validation loss: 1.296852
validation acc: 0.468533
best model saved
time: 28.27 sec
Epoch 12



Epoch 12/10000 train: 100%|██████████| 100/100 [00:21<00:00,  4.70it/s]

train loss: 1.210047
train acc: 0.518400



validation: 100%|██████████| 100/100 [00:06<00:00, 14.52it/s]

validation loss: 1.324509
validation acc: 0.451067
time: 28.23 sec
Epoch 13



Epoch 13/10000 train: 100%|██████████| 100/100 [00:21<00:00,  4.69it/s]

train loss: 1.209162
train acc: 0.515333



validation: 100%|██████████| 100/100 [00:06<00:00, 14.46it/s]

validation loss: 1.313265
validation acc: 0.455200
time: 28.28 sec
Epoch 14



Epoch 14/10000 train: 100%|██████████| 100/100 [00:21<00:00,  4.69it/s]

train loss: 1.195520
train acc: 0.525200



validation: 100%|██████████| 100/100 [00:06<00:00, 14.49it/s]

validation loss: 1.295556
validation acc: 0.465733
best model saved
time: 28.31 sec
Epoch 15



Epoch 15/10000 train: 100%|██████████| 100/100 [00:21<00:00,  4.69it/s]

train loss: 1.180341
train acc: 0.523467



validation: 100%|██████████| 100/100 [00:06<00:00, 14.48it/s]

validation loss: 1.293404
validation acc: 0.472400
best model saved
time: 28.33 sec
Epoch 16



Epoch 16/10000 train: 100%|██████████| 100/100 [00:21<00:00,  4.68it/s]

train loss: 1.179506
train acc: 0.535733



validation: 100%|██████████| 100/100 [00:06<00:00, 14.45it/s]

validation loss: 1.278445
validation acc: 0.474000
best model saved
time: 28.35 sec
Epoch 17



Epoch 17/10000 train: 100%|██████████| 100/100 [00:21<00:00,  4.69it/s]

train loss: 1.156710
train acc: 0.540267



validation: 100%|██████████| 100/100 [00:06<00:00, 14.52it/s]

validation loss: 1.274048
validation acc: 0.476933
best model saved
time: 28.32 sec
Epoch 18



Epoch 18/10000 train: 100%|██████████| 100/100 [00:21<00:00,  4.69it/s]

train loss: 1.159234
train acc: 0.542533



validation: 100%|██████████| 100/100 [00:06<00:00, 14.48it/s]

validation loss: 1.237698
validation acc: 0.500533
best model saved
time: 28.31 sec
Epoch 19



Epoch 19/10000 train: 100%|██████████| 100/100 [00:21<00:00,  4.69it/s]

train loss: 1.159303
train acc: 0.537333



validation: 100%|██████████| 100/100 [00:06<00:00, 14.50it/s]

validation loss: 1.295219
validation acc: 0.491067
time: 28.29 sec
Epoch 20



Epoch 20/10000 train: 100%|██████████| 100/100 [00:21<00:00,  4.69it/s]

train loss: 1.130165
train acc: 0.556400



validation: 100%|██████████| 100/100 [00:06<00:00, 14.53it/s]

validation loss: 1.279875
validation acc: 0.473733
time: 28.27 sec
Epoch 21



Epoch 21/10000 train: 100%|██████████| 100/100 [00:21<00:00,  4.69it/s]

train loss: 1.116047
train acc: 0.566533



validation: 100%|██████████| 100/100 [00:06<00:00, 14.51it/s]

validation loss: 1.237519
validation acc: 0.502667
best model saved
time: 28.31 sec
Epoch 22



Epoch 22/10000 train: 100%|██████████| 100/100 [00:21<00:00,  4.69it/s]

train loss: 1.096828
train acc: 0.570933



validation: 100%|██████████| 100/100 [00:06<00:00, 14.53it/s]

validation loss: 1.230772
validation acc: 0.512667
best model saved
time: 28.30 sec
Epoch 23



Epoch 23/10000 train: 100%|██████████| 100/100 [00:21<00:00,  4.69it/s]

train loss: 1.138094
train acc: 0.553200



validation: 100%|██████████| 100/100 [00:06<00:00, 14.53it/s]

validation loss: 1.215932
validation acc: 0.510533
best model saved
time: 28.29 sec
Epoch 24



Epoch 24/10000 train: 100%|██████████| 100/100 [00:21<00:00,  4.68it/s]

train loss: 1.096061
train acc: 0.573733



validation: 100%|██████████| 100/100 [00:06<00:00, 14.56it/s]

validation loss: 1.213117
validation acc: 0.511733
best model saved
time: 28.31 sec
Epoch 25



Epoch 25/10000 train: 100%|██████████| 100/100 [00:21<00:00,  4.68it/s]

train loss: 1.077407
train acc: 0.580533



validation: 100%|██████████| 100/100 [00:06<00:00, 14.46it/s]

validation loss: 1.231584
validation acc: 0.504000
time: 28.32 sec
Epoch 26



Epoch 26/10000 train: 100%|██████████| 100/100 [00:21<00:00,  4.67it/s]

train loss: 1.122819
train acc: 0.556000



validation: 100%|██████████| 100/100 [00:07<00:00, 14.25it/s]

validation loss: 1.253557
validation acc: 0.496667
time: 28.48 sec
Epoch 27



Epoch 27/10000 train: 100%|██████████| 100/100 [00:21<00:00,  4.67it/s]

train loss: 1.075812
train acc: 0.585067



validation: 100%|██████████| 100/100 [00:06<00:00, 14.35it/s]

validation loss: 1.192590
validation acc: 0.518933
best model saved
time: 28.46 sec
Epoch 28



Epoch 28/10000 train: 100%|██████████| 100/100 [00:21<00:00,  4.68it/s]

train loss: 1.101419
train acc: 0.565067



validation: 100%|██████████| 100/100 [00:06<00:00, 14.52it/s]

validation loss: 1.209191
validation acc: 0.520000
time: 28.30 sec
Epoch 29



Epoch 29/10000 train: 100%|██████████| 100/100 [00:21<00:00,  4.69it/s]

train loss: 1.067651
train acc: 0.577200



validation: 100%|██████████| 100/100 [00:06<00:00, 14.49it/s]

validation loss: 1.247000
validation acc: 0.501600
time: 28.31 sec
Epoch 30



Epoch 30/10000 train: 100%|██████████| 100/100 [00:21<00:00,  4.68it/s]

train loss: 1.078936
train acc: 0.575067



validation: 100%|██████████| 100/100 [00:06<00:00, 14.59it/s]

validation loss: 1.214851
validation acc: 0.509467
time: 28.27 sec
Epoch 31



Epoch 31/10000 train: 100%|██████████| 100/100 [00:21<00:00,  4.69it/s]

train loss: 1.086512
train acc: 0.577600



validation: 100%|██████████| 100/100 [00:06<00:00, 14.45it/s]

validation loss: 1.261733
validation acc: 0.493867
time: 28.29 sec
Epoch 32



Epoch 32/10000 train: 100%|██████████| 100/100 [00:21<00:00,  4.69it/s]

train loss: 1.041196
train acc: 0.597600



validation: 100%|██████████| 100/100 [00:06<00:00, 14.57it/s]

validation loss: 1.255529
validation acc: 0.506933
time: 28.25 sec
Epoch 33



Epoch 33/10000 train: 100%|██████████| 100/100 [00:21<00:00,  4.69it/s]

train loss: 1.056150
train acc: 0.588533



validation: 100%|██████████| 100/100 [00:06<00:00, 14.56it/s]

validation loss: 1.214480
validation acc: 0.509200
time: 28.25 sec
Epoch 34



Epoch 34/10000 train: 100%|██████████| 100/100 [00:21<00:00,  4.69it/s]

train loss: 1.059453
train acc: 0.591733



validation: 100%|██████████| 100/100 [00:06<00:00, 14.51it/s]

validation loss: 1.178192
validation acc: 0.532933
best model saved
time: 28.30 sec
Epoch 35



Epoch 35/10000 train: 100%|██████████| 100/100 [00:21<00:00,  4.69it/s]

train loss: 1.050289
train acc: 0.588667



validation: 100%|██████████| 100/100 [00:06<00:00, 14.56it/s]

validation loss: 1.200378
validation acc: 0.525600
time: 28.25 sec
Epoch 36



Epoch 36/10000 train: 100%|██████████| 100/100 [00:21<00:00,  4.69it/s]

train loss: 1.045259
train acc: 0.596133



validation: 100%|██████████| 100/100 [00:06<00:00, 14.54it/s]

validation loss: 1.169670
validation acc: 0.526533
best model saved
time: 28.28 sec
Epoch 37



Epoch 37/10000 train: 100%|██████████| 100/100 [00:21<00:00,  4.69it/s]

train loss: 1.044351
train acc: 0.594400



validation: 100%|██████████| 100/100 [00:06<00:00, 14.52it/s]

validation loss: 1.182223
validation acc: 0.527467
time: 28.27 sec
Epoch 38



Epoch 38/10000 train: 100%|██████████| 100/100 [00:21<00:00,  4.68it/s]

train loss: 1.035925
train acc: 0.601467



validation: 100%|██████████| 100/100 [00:06<00:00, 14.55it/s]

validation loss: 1.174496
validation acc: 0.531200
time: 28.28 sec
Epoch 39



Epoch 39/10000 train: 100%|██████████| 100/100 [00:21<00:00,  4.69it/s]

train loss: 0.993015
train acc: 0.614933



validation: 100%|██████████| 100/100 [00:06<00:00, 14.45it/s]

validation loss: 1.186807
validation acc: 0.521600
time: 28.32 sec
Epoch 40



Epoch 40/10000 train: 100%|██████████| 100/100 [00:21<00:00,  4.68it/s]

train loss: 1.011571
train acc: 0.606133



validation: 100%|██████████| 100/100 [00:06<00:00, 14.36it/s]

validation loss: 1.161065
validation acc: 0.536667
best model saved
time: 28.44 sec
Epoch 41



Epoch 41/10000 train:  65%|██████▌   | 65/100 [00:13<00:07,  4.65it/s]

In [None]:
# Mount Google Drive into the Colab runtime at /content/drive.
# NOTE(review): DATA_DIR is defined under /content/drive, so this mount presumably has to
# run before main() loads the data - consider moving this cell above the training cell
# so that "Restart & Run All" works from a fresh kernel.
from google.colab import drive
drive.mount('/content/drive')