In [4]:
"""
Authors:
Nawras Abbas    315085043
Michael Bikman  317920999
"""
import torch
from torch import nn
from torch.distributions import Bernoulli
import torch.nn.functional as F
import torch.nn as nn
import os
import numpy as np
import torch.nn
from torch.autograd import Variable
from torch.nn import functional
import logging
import os
import sys
from datetime import datetime
import math
import random
import time
import torch
from tqdm import tqdm
import pickle

# Timestamp captured once at import time; setup_reports() uses it to name the report folder and log file.
TIMESTAMP = datetime.now().strftime("%Y_%m_%d-%H_%M_%S")

# Location of the pickled mini-ImageNet splits (a Google Drive path as mounted in Colab).
DATA_DIR = '/content/drive/MyDrive/Colab Notebooks/project'
TRAIN_DATA_FILE = r"mini-imagenet-cache-train.pkl"
VALID_DATA_FILE = r"mini-imagenet-cache-val.pkl"
TEST_DATA_FILE = r"mini-imagenet-cache-test.pkl"

# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device('cpu')
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')


# NOTE(review): SEED / USE_SEED are not consumed anywhere in this part of the file;
# presumably the run script seeds the RNGs with them - confirm.
SEED = 4444
USE_SEED = True
LOG_ENABLED = True
MODEL_FILE = '/content/drive/MyDrive/Colab Notebooks/project/best_model.pt'  # absolute path to the saved best model (same Drive folder as DATA_DIR)

# `log` is the callable handed to train()/test(): plain print, or logging.info when logging is enabled.
log = print
if LOG_ENABLED:
    log = logging.info





# ********************************************* resnet12 *********************************************
# ********************************************* resnet12 *********************************************
# ********************************************* resnet12 *********************************************
# ********************************************* resnet12 *********************************************
# ********************************************* resnet12 *********************************************


class DropBlock(nn.Module):
    """
    DropBlock regularisation layer.

    During training it zeroes square ``block_size`` x ``block_size`` patches of the feature map and
    rescales the surviving activations by (number of elements / number of kept elements).
    In evaluation mode it returns its input unchanged.
    """

    def __init__(self, block_size):
        """
        :param block_size: side length of each dropped square patch
        """
        super(DropBlock, self).__init__()
        self.block_size = block_size

    def forward(self, x, gamma):
        """
        :param x: feature map indexed as (batch, channels, height, width)
        :param gamma: Bernoulli probability that a position becomes the top-left corner of a dropped block
        :return: tensor with the shape of x; the input itself when the module is in eval mode
        """
        if not self.training:
            return x

        batch_size, channels, height, width = x.shape

        # Seeds are sampled on the reduced grid so that every block fits inside the feature map.
        seed_shape = (batch_size, channels, height - (self.block_size - 1), width - (self.block_size - 1))
        # Follow the input's device instead of a module-level constant.
        mask = Bernoulli(gamma).sample(seed_shape).to(x.device)
        block_mask = self._compute_block_mask(mask)

        total = block_mask.numel()
        # clamp avoids 0/0 -> NaN when every position was dropped (the output is then all zeros)
        kept = block_mask.sum().clamp(min=1)

        return block_mask * x * (total / kept)

    def _compute_block_mask(self, mask):
        """
        Expand every sampled seed into a full block and return the keep-mask.

        :param mask: 4-D 0/1 tensor of seeds on the reduced (height - block_size + 1) grid
        :return: tensor of the full feature-map size, 0 where an activation is dropped and 1 elsewhere
        """
        left_padding = int((self.block_size - 1) / 2)
        right_padding = int(self.block_size / 2)
        # Padding by block_size - 1 in total brings the seed grid back to the feature-map size.
        padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding))

        seeds = mask.nonzero()
        nr_blocks = seeds.shape[0]

        if nr_blocks > 0:
            cells = self.block_size ** 2
            steps = torch.arange(self.block_size, device=mask.device)
            # (row, col) offset of every cell inside a block, prefixed with zeros for batch and channel
            offsets = torch.stack(
                [steps.repeat_interleave(self.block_size), steps.repeat(self.block_size)], dim=1)
            offsets = torch.cat(
                (torch.zeros(cells, 2, dtype=torch.long, device=mask.device), offsets), dim=1)

            # Each seed is repeated `cells` times consecutively while the offset table is tiled,
            # so every (seed, offset) combination is produced exactly once.
            block_idxs = seeds.repeat_interleave(cells, dim=0) + offsets.repeat(nr_blocks, 1)
            padded_mask[block_idxs[:, 0], block_idxs[:, 1], block_idxs[:, 2], block_idxs[:, 3]] = 1.

        return 1 - padded_mask


# This ResNet network was designed following the practice of the following papers:
# TADAM: Task dependent adaptive metric for improved few-shot learning (Oreshkin et al., in NIPS 2018) and
# A Simple Neural Attentive Meta-Learner (Mishra et al., in ICLR 2018).

def conv3x3(in_planes, out_planes, stride=1):
    """
    Build a bias-free 3x3 convolution with a padding of 1.

    :param in_planes: number of input channels
    :param out_planes: number of output channels
    :param stride: convolution stride
    :return: the configured nn.Conv2d layer
    """
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer


class ResNet12_BasicBlock(nn.Module):
    """
    Residual block of the ResNet-12 backbone.

    Three (3x3 conv -> batch norm) layers with LeakyReLU(0.1) between them, a residual connection,
    max pooling with kernel size ``stride`` and, optionally, DropBlock on the pooled output.
    """

    def __init__(self, in_planes, planes, stride=1, down_sample=None, drop_rate=0.0, drop_block=False, block_size=1):
        """
        :param in_planes: number of input channels
        :param planes: number of output channels
        :param stride: kernel size handed to the max-pool layer (the convolutions themselves use stride 1)
        :param down_sample: optional module applied to the input to form the residual branch
        :param drop_rate: final drop rate of the DropBlock schedule; 0 disables dropping
        :param drop_block: apply DropBlock to the output; when False, drop_rate has no effect in this block
        :param block_size: side length of the dropped blocks
        """
        super(ResNet12_BasicBlock, self).__init__()
        self.conv1 = conv3x3(in_planes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.LeakyReLU(0.1, inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv3x3(planes, planes)
        self.bn3 = nn.BatchNorm2d(planes)
        self.max_pool = nn.MaxPool2d(stride)
        self.down_sample = down_sample
        self.stride = stride
        self.drop_rate = drop_rate
        self.num_batches_tracked = 0  # number of training forward passes; drives the drop-rate ramp
        self.drop_block = drop_block
        self.block_size = block_size
        self.DropBlock = DropBlock(block_size=self.block_size)

    def forward(self, x):
        """
        :param x: input feature map indexed as (batch, channels, height, width)
        :return: the pooled (and, in training, possibly block-dropped) feature map
        """
        # Only training batches advance the schedule; previously evaluation passes were counted
        # as well, which made the drop rate ramp up faster than intended.
        if self.training:
            self.num_batches_tracked += 1

        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.down_sample is not None:
            residual = self.down_sample(x)
        out += residual
        out = self.relu(out)
        out = self.max_pool(out)

        if self.drop_rate > 0 and self.drop_block:
            feat_size = out.size()[2]
            # keep rate decays linearly from 1.0 to (1 - drop_rate) over 20 * 2000 tracked batches
            keep_rate = max(1.0 - self.drop_rate / (20 * 2000) * self.num_batches_tracked, 1.0 - self.drop_rate)
            # convert the per-activation drop probability into a per-seed probability
            gamma = (1 - keep_rate) / self.block_size ** 2 * feat_size ** 2 / (feat_size - self.block_size + 1) ** 2
            out = self.DropBlock(out, gamma=gamma)

        return out


class ResNet(nn.Module):
    """
    Four-stage convolutional feature extractor.

    Every stage consists of a single ``block`` with 64 output channels; the last two stages are
    built with DropBlock enabled. The forward pass returns the flattened output of the last stage.
    """

    def __init__(self, block, drop_rate, drop_block_size=5):
        """
        :param block: block class, called as block(in_planes, planes, stride, down_sample, drop_rate,
                      drop_block, block_size)
        :param drop_rate: drop rate forwarded to every block
        :param drop_block_size: DropBlock block size used by stages 3 and 4
        """
        super(ResNet, self).__init__()
        self.in_planes = 3  # channels entering the next stage; _make_layer updates it

        widths = [64, 64, 64, 64]

        self.layer1 = self._make_layer(block, widths[0], stride=2, drop_rate=drop_rate)
        self.layer2 = self._make_layer(block, widths[1], stride=2, drop_rate=drop_rate)
        self.layer3 = self._make_layer(
            block, widths[2], stride=2, drop_rate=drop_rate, drop_block=True, block_size=drop_block_size)
        self.layer4 = self._make_layer(
            block, widths[3], stride=2, drop_rate=drop_rate, drop_block=True, block_size=drop_block_size)

        self.drop_rate = drop_rate
        self.flatten = nn.Sequential(
            nn.Flatten(start_dim=1)
        )

        # Kaiming-normal init for convolutions, unit scale / zero shift for batch norms.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='leaky_relu')
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def _make_layer(self, block, planes, stride=1, drop_rate=0.0, drop_block=False, block_size=1):
        """
        Build one stage: a single block, with a 1x1 conv + batch-norm projection on the residual
        branch unless the stage keeps both resolution and channel count.

        :return: nn.Sequential wrapping the block
        """
        keeps_shape = stride == 1 and self.in_planes == planes
        projection = None
        if not keeps_shape:
            projection = nn.Sequential(
                nn.Conv2d(self.in_planes, planes, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(planes),
            )
        stage = nn.Sequential(
            block(self.in_planes, planes, stride, projection, drop_rate, drop_block, block_size))
        self.in_planes = planes
        return stage

    def forward(self, x):
        """
        :param x: image batch indexed as (batch, channels, height, width)
        :return: flattened features, one row per input image
        """
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        return self.flatten(x)


def resnet12(drop_rate):
    """
    Factory for the ResNet-12 backbone.

    :param drop_rate: drop rate forwarded to the network's blocks
    :return: a ResNet assembled from ResNet12_BasicBlock stages
    """
    backbone = ResNet(ResNet12_BasicBlock, drop_rate)
    return backbone





# ********************************************* resnet18 *********************************************
# ********************************************* resnet18 *********************************************
# ********************************************* resnet18 *********************************************
# ********************************************* resnet18 *********************************************
# ********************************************* resnet18 *********************************************




class ResNet18_BasicBlock(nn.Module):
    """
    ResNet-18 basic block: two (3x3 conv -> batch norm) layers with a residual connection.

    This class used to be defined twice under the same name, so the second (always down-sampling)
    definition silently replaced the first and even the ``(64, 64)`` blocks halved the resolution.
    The two variants are merged here:

    * ``in_channels == out_channels`` -> identity shortcut, resolution preserved;
    * ``in_channels != out_channels`` -> first conv uses stride 2 and the shortcut is a
      strided 1x1 conv + batch norm projection.

    NOTE: identity blocks no longer own ``down_sample`` parameters, so checkpoints saved with the
    previous (shadowed) definition will not load into this one.
    """

    def __init__(self, in_channels, out_channels):
        """
        :param in_channels: number of input channels
        :param out_channels: number of output channels; a value different from in_channels
                             selects the down-sampling variant
        """
        super(ResNet18_BasicBlock, self).__init__()
        widen = in_channels != out_channels
        stride = 2 if widen else 1

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.Sequential(
            nn.ReLU(inplace=True),
        )
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        if widen:
            # projection that matches the main branch in channel count and resolution
            self.down_sample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2, bias=False),
                nn.BatchNorm2d(out_channels)
            )
        else:
            self.down_sample = None

    def forward(self, x):
        """
        :param x: feature map indexed as (batch, channels, height, width)
        :return: block output; same spatial size for identity blocks, halved for widening blocks
        """
        residual = x if self.down_sample is None else self.down_sample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        out += residual
        out = self.relu(out)
        return out


class ResNet18(nn.Module):
    """
    ResNet-18 style feature extractor without a classification head.

    A 7x7 stride-2 stem (conv, batch norm, ReLU, 3x3 stride-2 max pool) is followed by four stages
    of two ResNet18_BasicBlock modules each, with 64, 128, 256 and 512 output channels.
    The output of the last stage is flattened to one vector per image.
    """

    def __init__(self):
        super().__init__()
        # stem
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.Sequential(
            nn.ReLU(inplace=True),
        )
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # layer1 .. layer4: the first block of a stage changes the width, the second keeps it
        in_width = 64
        for stage_no, out_width in enumerate((64, 128, 256, 512), start=1):
            stage = nn.Sequential(
                ResNet18_BasicBlock(in_width, out_width),
                ResNet18_BasicBlock(out_width, out_width)
            )
            setattr(self, f'layer{stage_no}', stage)
            in_width = out_width

        # despite its name this only flattens the feature map
        self.classifier = nn.Sequential(
            nn.Flatten(start_dim=1)
        )

    def forward(self, x):
        """
        :param x: image batch indexed as (batch, 3, height, width)
        :return: flattened feature vectors, one row per image
        """
        stem = (self.conv1, self.bn1, self.relu, self.max_pool)
        stages = (self.layer1, self.layer2, self.layer3, self.layer4)
        for module in stem + stages:
            x = module(x)
        return self.classifier(x)


# ********************************************* models *********************************************
# ********************************************* models *********************************************
# ********************************************* models *********************************************
# ********************************************* models *********************************************
# ********************************************* models *********************************************




def euclidean_dist(x, y):
    """
    Squared Euclidean distance between every query and every prototype.

    :param x: size [n_query_total, out_dim] - queries
    :param y: size [n_ways, out_dim] - prototypes
    :return: size [n_query_total, n_ways]; entry (i, j) is the sum over the embedding
             dimension of (x[i] - y[j]) ** 2
    :raises ValueError: when the two embedding dimensions differ
    """
    n_points = x.size(0)
    n_protos = y.size(0)
    dim = x.size(1)
    proto_dim = y.size(1)
    if dim != proto_dim:
        raise ValueError(f'Pic embedding for prototype {proto_dim} and query {dim} data arent equal')

    # tile both operands to [n_query_total, n_ways, out_dim] so they can be subtracted element-wise
    target_shape = (n_points, n_protos, dim)
    queries = x.unsqueeze(1).expand(*target_shape)
    protos = y.unsqueeze(0).expand(*target_shape)

    squared_diff = (queries - protos).pow(2)
    return squared_diff.sum(2)


def mahalanobis_dist(x, y, n_queries=15):
    """
    Mahalanobis-style distance between every query and every class prototype.

    For each class the covariance is estimated from that class's own queries and only its diagonal
    is kept. Results are written with ``.item()``, so the returned tensor carries no gradient.

    :param x: size [n_query_total, out_dim] - queries, expected to be grouped by class with
              n_queries consecutive rows per class
    :param y: size [n_ways, out_dim] - prototypes
    :param n_queries: number of query rows per class (defaults to the previously hard-coded 15)
    :return: size [n_query_total, n_ways] tensor on the device of x
    """
    device = x.device
    n_query_total = x.size(0)
    n_ways = y.size(0)  # number of classes = n_ways
    res = torch.zeros(n_query_total, n_ways).to(device)  # size = [n_query_total, n_ways]
    queries_per_class = x.split(n_queries, dim=0)  # tuple of [n_queries, out_dim] chunks
    prototypes_per_class = y.split(1, dim=0)  # tuple of [1, out_dim] chunks
    batches = n_query_total // n_queries
    for class_ndx in range(n_ways):
        class_queries = queries_per_class[class_ndx].detach().cpu()
        proto = prototypes_per_class[class_ndx]

        # The covariance depends only on the class, so it is computed once per class
        # instead of once per query batch.
        cov_arr = np.cov(class_queries.T)
        cov_diag = torch.diag(torch.from_numpy(cov_arr).to(device))
        cov = torch.diag(cov_diag)  # keep only the per-dimension variances

        for query_batch_ndx in range(batches):
            query_batch = queries_per_class[query_batch_ndx]
            for query_ndx in range(n_queries):
                query = query_batch[query_ndx, :]
                dist = mahalanobis(proto, query, cov)
                q = n_queries * query_batch_ndx + query_ndx
                res[q, class_ndx] = dist.item()
    return res


def mahalanobis(u, v, cov):
    """
    Mahalanobis distance sqrt((u - v) * cov^-1 * (u - v)^T), evaluated in double precision.

    :param u: first point; ``u - v`` must be a 2-D row vector of size [1, dim]
    :param v: second point
    :param cov: square covariance matrix of size [dim, dim]; it is inverted here
    :return: size [1, 1] tensor holding the distance
    """
    diff = (u - v).double()
    precision = torch.inverse(cov).double()
    squared = diff.matmul(precision).matmul(diff.transpose(0, 1).double())
    return squared.sqrt()


class ProtoNetSimple(nn.Module):
    """
    Prototypical network with a four-block convolutional embedding.

    Each block is conv3x3 -> batch norm -> ReLU -> 2x2 max pool; the output of the last block is
    flattened into the embedding vector. ``alpha`` is a learnable scalar that scales the
    query-to-prototype distances before the softmax.
    """

    def __init__(self, num_filters=128):
        """
        :param num_filters: number of channels produced by every convolutional block
        """
        super().__init__()

        self.file_name = 'ProtoNetSimple.pt'  # file name used by save()

        self.block1 = self._cnn_block(3, num_filters)
        self.block2 = self._cnn_block(num_filters, num_filters)
        self.block3 = self._cnn_block(num_filters, num_filters)
        self.block4 = self._cnn_block(num_filters, num_filters)

        self.classifier = nn.Sequential(
            nn.Flatten(start_dim=1)
        )

        self.alpha = nn.Parameter(torch.tensor(1.0, requires_grad=True))

    @staticmethod
    def _cnn_block(in_channels, out_channels):
        """
        :return: conv3x3 -> batch norm -> ReLU -> 2x2 max pool
        """
        block = torch.nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2)
        )
        return block

    def forward(self, x):
        """
        :param x: image batch indexed as (batch, 3, height, width)
        :return: embedding vectors, one row per image
        """
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        vector = self.classifier(x)
        return vector

    def loss(self, episode, image_data):
        """
        Compute the prototypical loss and accuracy of one episode.

        :param episode: object exposing get_support_sample(image_data) and get_query_sample(image_data);
                        both must return tensors of size [n_ways, n_examples, channels, height, width]
        :param image_data: image tensor the episode indexes into
        :return: tuple (loss tensor, {'loss': float, 'acc': float})
        :raises ValueError: when support and query disagree on the number of classes or the picture size
        """
        xs = episode.get_support_sample(image_data)  # [n_ways, n_shots, channels, height, width]
        xq = episode.get_query_sample(image_data)  # [n_ways, n_query_points, channels, height, width]

        n_class = xs.size(0)
        if xq.size(0) != n_class:
            raise ValueError(f'Number of classes for support {xs.size(0)} and query {xq.size(0)} data is not equal')
        n_support = xs.size(1)
        n_query = xq.size(1)

        # run on whatever device the model lives on instead of a module-level constant
        device = next(self.parameters()).device

        # ground truth: every query of class k has label k; size [n_class, n_query, 1]
        target_indices = torch.arange(n_class, device=device).view(n_class, 1, 1).expand(n_class, n_query, 1)

        support_pic_size = xs.size()[2:]
        query_pic_size = xq.size()[2:]
        if query_pic_size != support_pic_size:
            raise ValueError(f'Pic sizes for support {support_pic_size} and query {query_pic_size} data arent equal')
        n_support_total = n_class * n_support
        n_query_total = n_class * n_query
        xs_view = xs.view(n_support_total, *support_pic_size)
        xq_view = xq.view(n_query_total, *query_pic_size)
        # single batch of [n_support_total + n_query_total, channels, height, width]
        x = torch.cat([xs_view, xq_view], 0).float().to(device)

        z = self(x)  # embeddings, size [n_support_total + n_query_total, z_dim]
        z_dim = z.size(-1)

        z_support = z[:n_support_total]
        # prototype = mean of the support embeddings of a class; size [n_class, z_dim]
        prototypes_per_class = z_support.view(n_class, n_support, z_dim).mean(1)
        query_vectors = z[n_support_total:]  # size [n_query_total, z_dim]

        dists_per_class = euclidean_dist(query_vectors, prototypes_per_class)  # size [n_query_total, n_ways]
        # dists_per_class = mahalanobis_dist(query_vectors, prototypes_per_class)  # size = [n_query_total, n_ways]

        # alpha is a learnable scale applied to the distances
        if self.alpha is not None:
            dists_per_class = torch.mul(self.alpha, dists_per_class)

        # log(p(y=k|x)), size [n_class, n_query, n_class]
        log_p_y = torch.nn.functional.log_softmax(-dists_per_class, dim=1).view(n_class, n_query, -1)
        loss_per_query = -log_p_y.gather(2, target_indices).reshape(-1)  # size [n_query_total]
        loss_val = loss_per_query.mean()  # average loss over all queries
        _, y_hat = log_p_y.max(dim=2)  # predicted class per query, size [n_class, n_query]

        # Drop only the trailing singleton axis: a bare squeeze() also removed the query axis when
        # n_query == 1 (or the class axis when n_class == 1) and made eq() broadcast to a matrix.
        acc_val = torch.eq(y_hat, target_indices.squeeze(-1)).float().mean()

        return loss_val, {
            'loss': loss_val.item(),
            'acc': acc_val.item()
        }

    def save(self, report_folder):
        """
        Write the model's state dict to ``<report_folder>/<self.file_name>``.
        """
        model_file = os.path.join(report_folder, self.file_name)
        torch.save(self.state_dict(), model_file)


class ProtoNetComplex(ProtoNetSimple):
    """
    ProtoNetSimple whose flattened features pass through a two-layer MLP head
    (Linear -> Dropout(0.5) -> ReLU -> Linear) that outputs a 128-dimensional embedding.
    """

    def __init__(self, num_filters=64):
        """
        :param num_filters: number of channels of every convolutional block
        """
        # The parent constructor already builds block1..block4 with num_filters channels,
        # so they are not rebuilt here.
        super().__init__(num_filters)

        self.file_name = 'ProtoNetComplex.pt'

        # assumes 84x84 inputs: four 2x2 max pools leave a 5x5 map, i.e. 1600 features for the
        # default 64 filters (the value that used to be hard-coded) - TODO confirm for other sizes
        flat_features = num_filters * 5 * 5

        self.classifier = nn.Sequential(
            nn.Flatten(start_dim=1),
            nn.Linear(flat_features, 1600),
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.Linear(1600, 128)
        )


class ProtoNetRes18(ProtoNetSimple):
    """
    Prototypical network that embeds images with the ResNet18 backbone.

    The convolutional blocks created by the parent constructor stay registered on the model but are
    not used by forward(); they are kept so that state-dict keys remain unchanged.
    """

    def __init__(self):
        super().__init__()
        self.res_net = ResNet18()
        self.file_name = 'ProtoNetRes18.pt'

    def forward(self, x):
        """
        :param x: image batch
        :return: flattened ResNet18 features, one row per image
        """
        # call the module rather than .forward() so that registered hooks are honoured
        return self.res_net(x)


class ProtoNetRes12(ProtoNetSimple):
    """
    Prototypical network that embeds images with the ResNet-12 backbone.

    The convolutional blocks created by the parent constructor stay registered on the model but are
    not used by forward(); they are kept so that state-dict keys remain unchanged.
    """

    def __init__(self, drop_rate=0.1):
        """
        :param drop_rate: drop rate forwarded to resnet12()
        """
        super().__init__()
        self.res_net = resnet12(drop_rate=drop_rate)
        self.file_name = 'ProtoNetRes12.pt'

    def forward(self, x):
        """
        :param x: image batch
        :return: flattened ResNet-12 features, one row per image
        """
        # call the module rather than .forward() so that registered hooks are honoured
        return self.res_net(x)


class ResNetSimple(ProtoNetSimple):
    """
    Small residual CNN embedding with four stages.

    Every stage adds a three-convolution main branch to a (conv -> batch norm) shortcut branch and
    then applies ReLU and 2x2 max pooling. Stage 1 uses block1 / blockRes1; stages 2, 3 and 4 all
    reuse the same block2 / blockRes2 modules, i.e. they share weights.
    """

    def __init__(self, num_filters=64):
        """
        :param num_filters: number of channels produced by block1 / block2 and their shortcuts
        """
        super().__init__()
        self.file_name = 'ResNetSimple.pt'
        self.maxPool = nn.MaxPool2d(kernel_size=2)
        self.alpha = nn.Parameter(torch.tensor(1.0, requires_grad=True))
        self.relu = nn.ReLU(inplace=True)
        self.flatten = nn.Sequential(
            nn.Flatten(start_dim=1)
        )
        self.BN = nn.BatchNorm2d(num_filters)  # not used by forward()

        # stage 1: RGB input -> num_filters channels
        self.block1 = self._cnn_block(3, num_filters)
        self.blockRes1 = self._cnn_block_Res(3, num_filters)

        # stages 2-4: num_filters -> num_filters channels
        self.block2 = self._cnn_block(num_filters, num_filters)
        self.blockRes2 = self._cnn_block_Res(num_filters, num_filters)

    @staticmethod
    def _cnn_block(in_channels, out_channels):
        """
        :return: main branch, conv -> ReLU -> conv -> ReLU -> conv (all 3x3 with padding 1)
        """
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        ]
        return torch.nn.Sequential(*layers)

    @staticmethod
    def _cnn_block_Res(in_channels, out_channels):
        """
        :return: shortcut branch, 3x3 conv -> batch norm
        """
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels)
        ]
        return torch.nn.Sequential(*layers)

    def _residual_stage(self, x, main_branch, shortcut_branch):
        """
        Run one stage: (main + shortcut) -> ReLU -> max pool.
        """
        shortcut = shortcut_branch(x)
        merged = main_branch(x) + shortcut
        return self.maxPool(self.relu(merged))

    def forward(self, x):
        """
        :param x: image batch indexed as (batch, 3, height, width)
        :return: flattened features, one row per image
        """
        # stage 1 has its own weights
        x = self._residual_stage(x, self.block1, self.blockRes1)
        # stages 2, 3 and 4 run the same modules three times
        for _ in range(3):
            x = self._residual_stage(x, self.block2, self.blockRes2)
        return self.flatten(x)






# ********************************************* parameters *********************************************
# ********************************************* parameters *********************************************
# ********************************************* parameters *********************************************
# ********************************************* parameters *********************************************
# ********************************************* parameters *********************************************




class Episode(object):
    """
    One few-shot episode: for every participating class, the indices of its support pictures
    and of its query pictures.
    """

    def __init__(self):
        self.SUPPORT_DATA_NDX = 0  # position of the support indices inside each stored tuple
        self.QUERY_DATA_NDX = 1  # position of the query indices inside each stored tuple
        self.indexes = {}  # class_name -> (support pics indices, query pics indices)

    def add_indices(self, class_name, support_ndxs, query_ndxs):
        """
        Register (or overwrite) the support and query indices of a class.
        """
        self.indexes[class_name] = (support_ndxs, query_ndxs)

    def get_support_sample(self, image_data):
        """
        :return: support pictures of all classes, see _get_data_sample
        """
        return self._get_data_sample(image_data, self.SUPPORT_DATA_NDX)

    def get_query_sample(self, image_data):
        """
        :return: query pictures of all classes, see _get_data_sample
        """
        return self._get_data_sample(image_data, self.QUERY_DATA_NDX)

    def _get_data_sample(self, image_data, data_index):
        """
        Gather the selected pictures of every class, in insertion order of the classes.

        :param image_data: 4-D tensor whose first axis is the picture index; the last axis is moved
                           to position 1 (presumably channels-last to channels-first - confirm)
        :param data_index: SUPPORT_DATA_NDX or QUERY_DATA_NDX
        :return: tensor of size [n_classes, n_pictures, image_data.shape[3], image_data.shape[1],
                 image_data.shape[2]]
        """
        per_class = []
        for index_pair in self.indexes.values():
            picked = image_data.index_select(0, torch.tensor(index_pair[data_index]))
            # (n, d1, d2, d3) -> (n, d3, d1, d2)
            per_class.append(picked.permute(0, 3, 1, 2))
        return torch.stack(per_class, dim=0)


class RunParameters(object):
    """
    Plain container with everything a train + validation run needs.
    All fields start empty / zero and are filled in by the caller.
    """

    def __init__(self):
        # model and optimisation objects
        self.model = None
        self.epochs = 0
        self.loss = None
        self.optimizer = None
        self.scheduler = None
        self.report_path = None  # path where to save PT file and logs
        # data sets
        self.train_data = None
        self.val_data = None
        # episode layout and early stopping
        self.episodes_per_epoch = 0
        self.patience = 0
        self.n_ways = 0
        self.n_support_examples = 0
        self.n_query_examples = 0

    def __str__(self):
        """
        :return: multi-line summary of the run parameters (model and data sets are not included)
        """
        rows = (
            f'Epochs: {self.epochs}',
            f'Loss: {self.loss}',
            f'Optimizer:{self.optimizer}',
            f'Scheduler:{self.scheduler}',
            f'report_path:{self.report_path}',
            f'episodes_per_epoch:{self.episodes_per_epoch}',
            f'patience:{self.patience}',
            f'n_ways:{self.n_ways}',
            f'n_support_examples:{self.n_support_examples}',
            f'n_query_examples:{self.n_query_examples}',
        )
        header = '--------Run params--------\n'
        footer = '-----------------------------\n'
        return header + '\n'.join(rows) + '\n' + footer


class TrainResult(object):
    """
    Per-epoch metrics collected during a train + validation run.
    """

    def __init__(self):
        self.best_epoch = 0
        # one entry per finished epoch
        self.train_loss_per_epoch = []
        self.train_accuracy_per_epoch = []
        self.validation_loss_per_epoch = []
        self.validation_accuracy_per_epoch = []

    def train_loss_min(self):
        """Lowest train loss over all recorded epochs."""
        losses = self.train_loss_per_epoch
        return min(losses)

    def valid_loss_min(self):
        """Lowest validation loss over all recorded epochs."""
        losses = self.validation_loss_per_epoch
        return min(losses)

    def train_accuracy(self):
        """Highest train accuracy over all recorded epochs."""
        accuracies = self.train_accuracy_per_epoch
        return max(accuracies)

    def validation_accuracy(self):
        """Highest validation accuracy over all recorded epochs."""
        accuracies = self.validation_accuracy_per_epoch
        return max(accuracies)

    def __str__(self):
        """
        :return: multi-line summary with the best loss / accuracy values and the best epoch
        """
        rows = (
            f'Train Loss: {self.train_loss_min()}',
            f'Train Acc: {self.train_accuracy()}',
            f'Valid Loss:{self.valid_loss_min()}',
            f'Valid Acc:{self.validation_accuracy()}',
            f'Best Epoch:{self.best_epoch}',
        )
        header = '--------Run result--------\n'
        footer = '-----------------------------\n'
        return header + '\n'.join(rows) + '\n' + footer


class TestResult:
    """
    Accuracy and loss measured by a test evaluation.
    """

    def __init__(self):
        self.loss = 0  # loss
        self.acc = 0  # accuracy

    def __str__(self):
        """
        :return: multi-line summary of the test accuracy and loss
        """
        rows = (
            f'Test Acc: {self.acc}',
            f'Test Loss:{self.loss}',
        )
        header = '--------Test result--------\n'
        footer = '-----------------------------\n'
        return header + '\n'.join(rows) + '\n' + footer




# ********************************************* utils *********************************************
# ********************************************* utils *********************************************
# ********************************************* utils *********************************************
# ********************************************* utils *********************************************
# ********************************************* utils *********************************************



def setup_reports():
    """
    Create the report folder ``reports/<TIMESTAMP>_report`` and configure the root logger to write
    both to ``<TIMESTAMP>.log`` inside that folder and to stdout.

    Models are saved into the same folder as the logs.

    :return: path of the report folder
    """
    report_path = os.path.join('reports', f'{TIMESTAMP}_report')
    os.makedirs(report_path, exist_ok=True)

    logging.getLogger('matplotlib.font_manager').disabled = True
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)

    # Build the path with os.path.join: the previous hard-coded backslash separator (also an
    # invalid escape sequence) put the log file outside the report folder on POSIX systems.
    log_path = os.path.join(report_path, f'{TIMESTAMP}.log')
    logger.addHandler(logging.FileHandler(log_path))
    logger.addHandler(logging.StreamHandler(sys.stdout))
    return report_path





# ********************************************* engine *********************************************
# ********************************************* engine *********************************************
# ********************************************* engine *********************************************
# ********************************************* engine *********************************************
# ********************************************* engine *********************************************


def choose_episode_classes(class_dict, n_ways):
    """
    Choose which classes will participate in the episode (uniformly, without replacement).

    :param class_dict: dictionary of all classes, keyed by class name
    :param n_ways: number of classes per episode
    :return: list of n_ways distinct class names
    :raises ValueError: when n_ways is negative or larger than the number of classes
    """
    # random.sample needs a sequence: passing the dict_keys view is deprecated since
    # Python 3.9 and raises TypeError from Python 3.11 on.
    all_classes = list(class_dict)
    return random.sample(all_classes, n_ways)


def create_episodes(class_dict, n_episodes, n_ways, n_supports, n_queries):
    """
    Randomly build up to n_episodes episodes.

    :param class_dict: class name -> collection of picture indices of that class
    :param n_episodes: number of episodes to attempt
    :param n_ways: number of classes per episode
    :param n_supports: number of support pictures per class
    :param n_queries: number of query pictures per class
    :return: list of Episode objects; an attempt whose class selection raises ValueError is skipped,
             so the list can be shorter than n_episodes
    """
    per_class = n_supports + n_queries
    built = []

    for _ in range(n_episodes):
        try:
            chosen = choose_episode_classes(class_dict, n_ways)
        except ValueError:
            # NOTE(review): the failure is swallowed silently, callers just get fewer episodes
            continue

        current = Episode()
        for class_name in chosen:
            # draw support + query indices in one go, then split them
            picked = random.sample(class_dict[class_name], per_class)
            current.add_indices(class_name, picked[:n_supports], picked[n_supports:])

        built.append(current)

    return built


def train(parameters, log):
    """
    Run the training procedure with early stopping on the validation loss.

    The model is saved every time the validation loss improves. Training stops early once more
    than ``parameters.patience`` consecutive epochs bring no improvement.

    :param log: logger to file / console (callable taking a message)
    :param parameters: training parameters context
    :return: train result; ``best_epoch`` is the 1-based number of the epoch with the lowest
             validation loss (0 when no epoch improved on the initial infinite loss)
    """
    log('Train started')
    train_res = TrainResult()

    best_loss = math.inf
    epochs_since_best = 0
    for ep in range(parameters.epochs):
        epoch_start_time = time.time()

        train_loop(ep, parameters, train_res, log)
        validation_loop(parameters, train_res, log, "validation")
        if parameters.scheduler is not None:
            parameters.scheduler.step()

        # early stopping
        epoch_loss = train_res.validation_loss_per_epoch[-1].item()
        if epoch_loss < best_loss:
            epochs_since_best = 0
            best_loss = epoch_loss
            # Record the best epoch as soon as it is found. Previously it was only set when early
            # stopping fired, so a run that used all its epochs reported best_epoch == 0.
            # The value matches what the early-stop branch used to compute (ep - patience).
            train_res.best_epoch = ep + 1
            parameters.model.save(parameters.report_path)
            log('best model saved')
        else:
            epochs_since_best = epochs_since_best + 1
        if epochs_since_best > parameters.patience:
            return train_res

        elapsed_time = time.time() - epoch_start_time
        log(f'time: {elapsed_time:5.2f} sec')

    return train_res


def train_loop(ep, parameters, train_res, log):
    """
    Train the model for one epoch on freshly sampled episodes.

    :param ep: zero-based epoch index (logged as ep + 1)
    :param parameters: run parameters; model, optimizer, train_data and the episode layout are read
    :param train_res: result object; the epoch's mean train loss and accuracy are appended to it
    :param log: callable taking a message
    :return: None
    """
    net = parameters.model
    opt = parameters.optimizer
    images = torch.from_numpy(parameters.train_data['image_data'])
    class_dict = parameters.train_data['class_dict']

    net.train()
    episodes = create_episodes(
        class_dict,
        parameters.episodes_per_epoch,
        parameters.n_ways,
        parameters.n_support_examples,
        parameters.n_query_examples,
    )
    log(f'Epoch {ep + 1}')

    losses, accuracies = [], []
    for episode in tqdm(episodes, desc=f"Epoch {ep + 1}/{parameters.epochs} train"):
        # one optimisation step per episode
        opt.zero_grad()
        loss, output = net.loss(episode, images)
        losses.append(output['loss'])
        accuracies.append(output['acc'])
        loss.backward()
        opt.step()

    # epoch metrics are kept as 0-dim tensors
    mean_loss = torch.tensor(losses).mean()
    mean_acc = torch.tensor(accuracies).mean()
    log(f'train loss: {mean_loss:.6f}')
    log(f'train acc: {mean_acc:.6f}')

    train_res.train_accuracy_per_epoch.append(mean_acc)
    train_res.train_loss_per_epoch.append(mean_loss)


def validation_loop(parameters, train_res, log, desc):
    """
    Evaluate the model on freshly sampled episodes without gradient tracking.

    The model is switched to evaluation mode, every episode is scored with
    ``model.loss`` and the mean loss / accuracy over the episodes are logged
    (prefixed with ``desc``) and appended to ``train_res``.

    :param parameters: run configuration providing ``model``, ``n_ways``,
        ``episodes_per_epoch``, ``n_support_examples``, ``n_query_examples``
        and ``val_data`` (a mapping with the keys ``'image_data'`` and
        ``'class_dict'``)
    :param train_res: result holder whose ``validation_accuracy_per_epoch``
        and ``validation_loss_per_epoch`` collections receive the means
    :param log: callable taking a single message
    :param desc: label used for the progress bar and as log prefix
    :return: None
    """
    net = parameters.model
    sampling_args = (
        parameters.episodes_per_epoch,
        parameters.n_ways,
        parameters.n_support_examples,
        parameters.n_query_examples,
    )
    dataset = parameters.val_data
    images = torch.from_numpy(dataset['image_data'])
    class_dict = dataset['class_dict']

    net.eval()
    losses, accuracies = [], []
    with torch.no_grad():
        episodes = create_episodes(class_dict, *sampling_args)
        for episode in tqdm(episodes, desc=desc):
            # Only the statistics dict is needed; the loss value itself is
            # discarded because no backward pass happens here.
            _unused_loss, stats = net.loss(episode, images)
            losses.append(stats['loss'])
            accuracies.append(stats['acc'])

    mean_loss = torch.tensor(losses).mean()
    mean_acc = torch.tensor(accuracies).mean()
    log(f'{desc} loss: {mean_loss:.6f}')
    log(f'{desc} acc: {mean_acc:.6f}')

    # The means are stored as 0-dim tensors, not Python floats.
    train_res.validation_accuracy_per_epoch.append(mean_acc)
    train_res.validation_loss_per_epoch.append(mean_loss)


def test(parameters, log):
    """
    Run one evaluation pass and package its metrics as a ``TestResult``.

    The pass is delegated to ``validation_loop`` (labelled "Test"), so the
    episodes are drawn from ``parameters.val_data``.

    :param parameters: run configuration forwarded to ``validation_loop``
    :param log: callable taking a single message
    :return: ``TestResult`` with ``acc`` and ``loss`` set to Python floats
    """
    log('Test started')
    # validation_loop reports through a TrainResult, so a throw-away one
    # is used purely as a container for the single pass.
    scratch = TrainResult()
    started_at = time.time()
    validation_loop(parameters, scratch, log, "Test")
    elapsed = time.time() - started_at
    log(f'time: {elapsed:5.2f} sec')

    outcome = TestResult()
    # The last appended entries belong to the pass that just ran.
    outcome.acc = scratch.validation_accuracy_per_epoch[-1].item()
    outcome.loss = scratch.validation_loss_per_epoch[-1].item()
    return outcome




# ********************************************* run_eval *********************************************
# ********************************************* run_eval *********************************************
# ********************************************* run_eval *********************************************
# ********************************************* run_eval *********************************************
# ********************************************* run_eval *********************************************





def load_test_data():
    """
    Read the pickled test split located at DATA_DIR/TEST_DATA_FILE.

    :return: the unpickled object (used by callers as a data dictionary)
    """
    test_path = os.path.join(DATA_DIR, TEST_DATA_FILE)
    with open(test_path, "rb") as handle:
        # pickle executes code embedded in the file; only load trusted data.
        loaded = pickle.load(handle)
    log(f'Loaded test data: {TEST_DATA_FILE}')
    return loaded


def main():
    """
    Evaluate the saved ProtoNetRes12 checkpoint on the test split.

    Steps: set up reporting, build the model on DEVICE, restore its weights
    from MODEL_FILE, load the pickled test data, configure a 5-way / 5-shot
    run with 15 queries over 100 episodes, and log the resulting test
    metrics.

    :return: None
    :raises Exception: when MODEL_FILE does not exist
    """
    setup_reports()
    log(DEVICE)
    log(f'SEED:{SEED}')

    model = ProtoNetRes12().to(DEVICE)

    log(model)

    if not os.path.isfile(MODEL_FILE):
        raise Exception(f'Cannot train model - no saved state found: {MODEL_FILE}!')
    # map_location remaps the stored tensors onto the device actually in use,
    # so a checkpoint written on a GPU can still be restored when DEVICE
    # falls back to the CPU.
    state_dict = torch.load(MODEL_FILE, map_location=DEVICE)
    model.load_state_dict(state_dict)

    # The first registered parameter is reported as "Alpha"; .item() requires
    # it to hold exactly one element.
    log(f'Alpha:{list(model.parameters())[0].item()}')

    log('Model load successful!')
    test_data = load_test_data()

    test_params = RunParameters()
    test_params.model = model
    # The evaluation loop reads its episodes from `val_data`, so the test
    # split is supplied through that field.
    test_params.val_data = test_data

    test_params.n_ways = 5  # ways
    test_params.episodes_per_epoch = 100
    test_params.n_support_examples = 5  # shots
    test_params.n_query_examples = 15
    log(test_params)

    if USE_SEED:
        # NOTE(review): only Python's `random` is seeded here; if episode
        # sampling draws from numpy or torch RNGs those need seeding too -
        # confirm against create_episodes.
        random.seed(SEED)

    test_res = test(test_params, log)
    log(test_res)


In [5]:
# Notebook entry point: run the evaluation defined by main() and print a
# completion marker once it returns without raising.
if __name__ == '__main__':
    main()
    print('OK')

cuda
cuda
SEED:4444
SEED:4444
ProtoNetRes12(
  (block1): Sequential(
    (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (block2): Sequential(
    (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (block3): Sequential(
    (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (block4): Sequential(
    (0): Con

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
Test: 100%|██████████| 100/100 [00:06<00:00, 15.17it/s]


Test loss: 0.735402
Test loss: 0.735402
Test acc: 0.710133
Test acc: 0.710133
time:  6.63 sec
time:  6.63 sec
--------Test result--------
Test Acc: 0.7101333141326904
Test Loss:0.7354016304016113
-----------------------------

--------Test result--------
Test Acc: 0.7101333141326904
Test Loss:0.7354016304016113
-----------------------------

OK


In [6]:
# Mount Google Drive at /content/drive (Colab only).
# NOTE(review): DATA_DIR and MODEL_FILE both live under /content/drive, yet
# this cell is numbered after the evaluation cell; on a fresh runtime it has
# to be executed first - consider moving it to the top of the notebook.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
