In [None]:
import copy
import glob
import logging
import os
import pdb
import pickle
import random
import re
from collections import OrderedDict
from types import SimpleNamespace

import numpy as np
import pandas as pd
import PIL.Image as Image
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision.transforms as transforms

In [None]:
from typing import Optional, Sized


class FewShotDataset(data.Dataset):
    """Map-style dataset whose items are whole classes instead of single images.

    ``os.path.join(root, phase)`` must contain one sub-directory per class.
    Item ``idx`` is built from the ``idx``-th class directory (sorted by name):
    its first ``n_shot + n_eval`` files (sorted by name) are loaded,
    transformed and stacked. Rows ``[:n_shot]`` are the training samples of
    the episode and rows ``[n_shot:]`` are the evaluation samples.

    Args:
        root: Directory that contains the per-phase folders.
        phase: Name of the sub-folder of ``root`` to read (e.g. ``'train'``).
        n_shot: Number of training images taken per class.
        n_eval: Number of evaluation images taken per class.
        transform: Optional callable applied to every loaded image. It has to
            return tensors of identical shape, because the results are passed
            to ``torch.stack``.
    """

    def __init__(self, root, phase, n_shot, n_eval, transform=None):
        self.root = os.path.join(root, phase)
        self.labels = sorted(os.listdir(self.root))
        self.n_shot = n_shot
        self.n_eval = n_eval
        self.transform = transform

    def _episode_paths(self, idx):
        """Return the image paths that make up the episode of class ``idx``."""
        class_dir = os.path.join(self.root, self.labels[idx])
        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # shot/eval split identical from run to run.
        fnames = sorted(os.listdir(class_dir))[:self.n_shot + self.n_eval]
        return [os.path.join(class_dir, fname) for fname in fnames]

    @staticmethod
    def _load_image(path):
        """Read one image fully into memory and release its file handle."""
        with Image.open(path) as img:
            # copy() forces the pixel data to be read before the file closes.
            return img.copy()

    def __getitem__(self, idx):
        """Return ``(episode_x, episode_y)`` for the class at position ``idx``.

        ``episode_x`` stacks the transformed images along a new first
        dimension; ``episode_y`` is a ``LongTensor`` holding ``idx`` once per
        stacked image.
        """
        # Only the files that end up in the episode are opened.
        samples = [self._load_image(path) for path in self._episode_paths(idx)]

        if self.transform is not None:
            samples = [self.transform(sample) for sample in samples]

        episode_x = torch.stack(samples, dim=0)
        episode_y = torch.LongTensor([idx] * len(samples))

        return episode_x, episode_y

    def __len__(self):
        """Number of class directories found under ``root/phase``."""
        return len(self.labels)
    

class ClassSet(data.Dataset):
    """Dataset over the image files of a single class.

    Args:
        images: Sequence of image file paths.
        label: Label returned alongside every image.
        transform: Optional callable applied to each loaded RGB image.
    """

    def __init__(self, images, label, transform=None):
        self.images = images
        self.label = label
        self.transform = transform

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the file at position ``idx``."""
        with open(self.images[idx], 'rb') as f:
            # convert() reads the pixel data, so the file can be closed safely.
            image = Image.open(f).convert('RGB')

        # The transform was stored by __init__ but never applied before.
        if self.transform is not None:
            image = self.transform(image)

        return image, self.label

    def __len__(self):
        """Number of image paths in this class."""
        return len(self.images)

def process_data(data):
    """Build the train / validation / test episode loaders.

    Args:
        data: Configuration mapping. Keys read here: ``image_size``,
            ``data_root``, ``n_shot``, ``n_eval``, ``n_class`` and
            ``n_workers``.

    Returns:
        Tuple ``(train_loader, validation_loader, test_loader)``. Each batch
        stacks ``n_class`` items of ``FewShotDataset``, i.e. one episode with
        ``n_class`` classes.
    """
    # The parameter is named ``data`` and therefore hides the module alias
    # ``data`` (torch.utils.data) inside this function; the loader class is
    # reached through ``torch.utils.data`` instead.
    cfg = data

    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.229, 0.224, 0.225])
    transform1 = transforms.Compose([
        transforms.RandomResizedCrop(cfg['image_size']),
        transforms.ToTensor(),
        normalize
    ])

    def build_loader(phase, n_workers, shuffle):
        # FewShotDataset is the only class in this notebook with the
        # (root, phase, n_shot, n_eval, transform) constructor that the
        # previously referenced, undefined ``Episode`` was called with.
        dataset = FewShotDataset(cfg['data_root'], phase, cfg['n_shot'], cfg['n_eval'],
                                 transform=transform1)
        # One dataset item is one class, so a batch of n_class items is one
        # n_class-way episode. drop_last avoids a final batch with fewer
        # classes than the label tensors built by the consumers expect.
        return torch.utils.data.DataLoader(dataset,
                                           batch_size=cfg['n_class'],
                                           shuffle=shuffle,
                                           drop_last=True,
                                           num_workers=n_workers)

    train_loader = build_loader('train', cfg['n_workers'], shuffle=True)
    validation_loader = build_loader('val', 2, shuffle=False)
    test_loader = build_loader('test', 2, shuffle=False)
    return train_loader, validation_loader, test_loader


In [None]:
def create_logger(dat):
    """Configure the root logger and return an empty statistics dict.

    In ``'train'`` mode everything from DEBUG upwards is written to
    ``<dat['save']>/console.log`` and INFO and above is echoed through a
    console handler. In any other mode only a basic INFO configuration is
    requested. ``logging.basicConfig`` does nothing when the root logger
    already has handlers, so calling this again in the same kernel does not
    reopen the log file.

    Args:
        dat: Configuration mapping. Keys read: ``mode``, ``seed`` and, in
            train mode, ``save``.

    Returns:
        ``{'train': {'loss': [], 'acc': []}}`` in train mode, otherwise
        ``{'eval': {'loss': [], 'acc': []}}``.
    """
    mode = dat['mode']

    if mode == 'train':
        save_root = dat['save']
        # makedirs also creates missing parent folders and tolerates an
        # already existing directory.
        os.makedirs(save_root, exist_ok=True)
        filename = os.path.join(save_root, 'console.log')
        logging.basicConfig(level=logging.DEBUG,
            format='%(asctime)s.%(msecs)03d - %(message)s',
            datefmt='%b-%d %H:%M:%S',
            filename=filename,
            filemode='w')

        # Re-running the cell must not stack one console handler per call,
        # otherwise every message is printed several times.
        console_name = 'few_shot_console'
        root_logger = logging.getLogger('')
        if all(handler.get_name() != console_name for handler in root_logger.handlers):
            console = logging.StreamHandler()
            console.set_name(console_name)
            console.setLevel(logging.INFO)
            console.setFormatter(logging.Formatter('%(message)s'))
            root_logger.addHandler(console)

        logging.info("Logger created at {}".format(filename))
    else:
        logging.basicConfig(level=logging.INFO,
            format='%(asctime)s.%(msecs)03d - %(message)s',
            datefmt='%b-%d %H:%M:%S')

    logging.info("Random Seed: {}".format(dat['seed']))

    stats = {'train': {'loss': [], 'acc': []}} if mode == 'train' else {'eval': {'loss': [], 'acc': []}}

    return stats

def reset_stats(stats):
    """Replace the recorded loss/acc histories with fresh empty lists.

    Only the ``'train'`` and ``'eval'`` entries that are present in ``stats``
    are touched; the dict is modified in place and nothing is returned.
    """
    for phase in ('train', 'eval'):
        if phase not in stats:
            continue
        # New list objects are assigned (the old lists are left untouched).
        stats[phase]['loss'] = []
        stats[phase]['acc'] = []

def log_batch_info(stats, kwargs, log_freq):
    """Record per-episode metrics and emit progress / summary log lines.

    Args:
        stats: Dict with a ``'train'`` and/or ``'eval'`` entry, each holding
            ``'loss'`` and ``'acc'`` lists (as returned by ``create_logger``).
        kwargs: Dict describing the event. ``kwargs['phase']`` selects it:

            * ``'train'``: needs ``loss``, ``acc`` and ``eps``; optional
              ``total_eps`` (falls back to ``eps``). The values are appended
              and, every ``log_freq`` episodes (except episode 0), the
              current and running-mean values are logged.
            * ``'eval'``: needs ``loss`` and ``acc``; values are appended.
            * ``'evaldone'`` / ``'eval_done'``: needs ``eps``; optional
              ``total_eps`` (falls back to the number of recorded eval
              losses). Logs mean and standard deviation of the recorded eval
              metrics.
        log_freq: Logging period, in episodes, for the train phase.

    Returns:
        The mean evaluation accuracy as ``float`` for the eval-done phase,
        ``None`` otherwise. ``stats`` is not reset by this function.
    """
    phase = kwargs['phase']

    if phase == 'train':
        stats['train']['loss'].append(kwargs['loss'])
        stats['train']['acc'].append(kwargs['acc'])

        eps = kwargs['eps']
        if eps % log_freq == 0 and eps != 0:
            loss_mean = np.mean(stats['train']['loss'])
            acc_mean = np.mean(stats['train']['acc'])
            # The total is taken from the event itself instead of a global
            # config dict, so the function does not depend on notebook state.
            total_eps = kwargs.get('total_eps', eps)
            logging.info("[{:5d}/{:5d}] loss: {:6.4f} ({:6.4f}), acc: {:6.3f}% ({:6.3f}%)".format(
                eps, total_eps, kwargs['loss'], loss_mean, kwargs['acc'], acc_mean))

    elif phase == 'eval':
        stats['eval']['loss'].append(kwargs['loss'])
        stats['eval']['acc'].append(kwargs['acc'])

    # Both spellings are accepted: test_meta emits 'eval_done'.
    elif phase in ('evaldone', 'eval_done'):
        loss_mean = np.mean(stats['eval']['loss'])
        loss_std = np.std(stats['eval']['loss'])
        acc_mean = np.mean(stats['eval']['acc'])
        acc_std = np.std(stats['eval']['acc'])
        n_episodes = kwargs.get('total_eps', len(stats['eval']['loss']))
        logging.info("[{:5d}] Eval ({:3d} episode) - loss: {:6.4f} +- {:6.4f}, acc: {:6.3f} +- {:5.3f}%".format(
            kwargs['eps'], n_episodes, loss_mean, loss_std, acc_mean, acc_std))
        return float(acc_mean)

    return None

In [None]:
def test_meta(eps, data_loader, model_w_grad, model_wo_grad, metalearner, config, metrics):
    for i, (x, y) in enumerate(tqdm(data_loader, ascii=True)):
        x_train = x[:, :config.n_shot].reshape(-1, *x.shape[-3:]).to(config.device)
        y_train = torch.LongTensor(np.repeat(range(config.n_class), config.n_shot)).to(config.device)
        x_test = x[:, config.n_shot:].reshape(-1, *x.shape[-3:]).to(config.device)
        y_test = torch.LongTensor(np.repeat(range(config.n_class), config.n_eval)).to(config.device)

        model_w_grad.reset_batch_stats()
        model_wo_grad.reset_batch_stats()
        model_w_grad.train()
        model_wo_grad.eval()
        cI = train_model(model_w_grad, metalearner, x_train, y_train, config)

        model_wo_grad.transfer_params(model_w_grad, cI)
        output = model_wo_grad(x_test)
        loss = model_wo_grad.criterion(output, y_test)
        acc = accuracy(output, y_test)

        metrics.update(loss=loss.item(), acc=acc, phase='eval')

    return metrics.update(eps=eps, total_eps=config.episode_val, phase='eval_done')


In [None]:
class Learner(nn.Module):
    """Four-block convolutional classifier used as the base learner.

    Every block is ``Conv2d(3x3, 64 filters, no bias) -> BatchNorm2d -> ReLU
    -> MaxPool2d(2)``, so the spatial size is halved four times before the
    linear classifier ``cls``.

    Args:
        image_size: Height/width of the (square) input images.
        eps: ``eps`` of every ``BatchNorm2d`` layer.
        momentum: ``momentum`` of every ``BatchNorm2d`` layer.
        n_classes: Number of output logits.

    NOTE(review): other cells call ``transfer_params`` on learner instances;
    that method is not defined here and still has to be supplied.
    """

    def __init__(self, image_size, eps, momentum, n_classes):
        # nn.Module must be initialised before sub-modules can be assigned.
        super().__init__()
        n_filters = 64
        self.model = nn.ModuleDict({'features': nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(3, n_filters, kernel_size=3, padding=1, bias=False)),
            ('bn1', nn.BatchNorm2d(n_filters, eps=eps, momentum=momentum)),
            ('relu1', nn.ReLU(inplace=True)),
            ('pool1', nn.MaxPool2d(2, 2)),
            ('conv2', nn.Conv2d(n_filters, n_filters, kernel_size=3, padding=1, bias=False)),
            ('bn2', nn.BatchNorm2d(n_filters, eps=eps, momentum=momentum)),
            ('relu2', nn.ReLU(inplace=True)),
            ('pool2', nn.MaxPool2d(2, 2)),
            ('conv3', nn.Conv2d(n_filters, n_filters, kernel_size=3, padding=1, bias=False)),
            ('bn3', nn.BatchNorm2d(n_filters, eps=eps, momentum=momentum)),
            ('relu3', nn.ReLU(inplace=True)),
            ('pool3', nn.MaxPool2d(2, 2)),
            ('conv4', nn.Conv2d(n_filters, n_filters, kernel_size=3, padding=1, bias=False)),
            ('bn4', nn.BatchNorm2d(n_filters, eps=eps, momentum=momentum)),
            ('relu4', nn.ReLU(inplace=True)),
            ('pool4', nn.MaxPool2d(2, 2))
        ]))})
        # Four 2x2 poolings shrink each side by 2**4. The flattened feature
        # size uses the channel count of conv4 (64); it was hard-coded to 32
        # before, which does not match the feature extractor's output.
        clr_in = image_size // 2**4
        self.model.update({'cls': nn.Linear(n_filters * clr_in * clr_in, n_classes)})
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return class logits of shape ``(batch, n_classes)`` for images ``x``."""
        features = self.model['features'](x)
        flat = features.reshape(features.size(0), -1)
        return self.model['cls'](flat)

    def reset_batch_stats(self):
        """Reset the running statistics of every ``BatchNorm2d`` layer."""
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.reset_running_stats()

class MetaLearner(nn.Module):
    """Two-stage recurrent meta-learner.

    Args:
        input_size: Feature size of the first LSTM cell's input, i.e. the
            width of the concatenated ``(loss, grad_p)`` tensor in ``forward``.
        hidden_size: Hidden size of both cells.
    """

    def __init__(self, input_size, hidden_size, ):
        # nn.Module must be initialised before sub-modules can be assigned.
        super().__init__()
        self.lstm = nn.LSTMCell(input_size, hidden_size)
        self.metalstm = nn.LSTMCell(input_size=hidden_size, hidden_size=hidden_size)

    # forward used to be defined inside __init__, where it was only a local
    # function and never became a method of the class.
    def forward(self, inputs):
        """Run one meta-learner step.

        Args:
            inputs: Sequence ``(loss, grad_p, grad)``. ``loss`` is broadcast
                to the shape of ``grad_p`` and both are concatenated along
                dimension 1 before entering the first LSTM cell.

        Returns:
            ``(flat_learner, [(lstm_h, lstm_c), metalstm_h])`` — the first
            output of ``self.metalstm`` and the states of both stages.
        """
        loss, grad_p, grad = inputs
        # expand() takes sizes; expand_as() takes the tensor to match.
        loss = loss.expand_as(grad_p)
        inputs = torch.cat((loss, grad_p), 1)

        lstm_h, lstm_c = self.lstm(inputs)
        # NOTE(review): nn.LSTMCell takes a tensor input and an optional
        # (h, c) tuple, so this call shape (list input, second argument 0)
        # fits a custom meta-LSTM cell rather than nn.LSTMCell. Other cells
        # also read `metalstm.ci`, which nn.LSTMCell does not have — confirm
        # which cell type is intended for `self.metalstm`.
        flat_learner, metalstm_h = self.metalstm([lstm_h, grad], 0)
        return flat_learner, [(lstm_h, lstm_c), metalstm_h]

In [None]:
def learn_train(learner, metalearner, input, target, dat):
    """Run the inner training loop of one episode.

    The episode data is visited ``dat['epoch']`` times in mini-batches of
    ``dat['batch_size']``. For every mini-batch the learner's loss and flat
    gradient are computed, preprocessed and handed to the meta-learner, whose
    first output becomes the new ``ci``.

    Args:
        learner: ``nn.Module`` with a ``criterion`` attribute.
        metalearner: Callable taking ``[loss_prep, grad_prep, grad]`` and
            returning ``(ci, state)``; must expose ``metalstm.ci``.
        input: Indexable batch of episode inputs.
        target: Labels aligned with ``input``.
        dat: Configuration mapping; keys read: ``epoch``, ``batch_size``.

    Returns:
        The ``ci`` produced by the last meta-learner call, or the initial
        ``metalearner.metalstm.ci.data`` if no mini-batch was processed.

    NOTE(review): `preprocess_grad` and `preprocess_loss` are not defined in
    the visible notebook cells. The returned ``ci`` is never written back to
    ``learner`` inside this loop, and the meta-learner state is not carried
    from one step to the next — confirm whether that is intended.
    """
    ci = metalearner.metalstm.ci.data
    batch_size = dat['batch_size']

    for _ in range(dat['epoch']):
        for start_idx in range(0, len(input), batch_size):
            # Slicing past the end is clamped automatically.
            x = input[start_idx:start_idx + batch_size]
            y = target[start_idx:start_idx + batch_size]

            output = learner(x)
            loss = learner.criterion(output, y)
            # Clear gradients left over from the previous mini-batch;
            # backward() accumulates into .grad otherwise.
            learner.zero_grad()
            loss.backward()
            grad = torch.cat([p.grad.data.view(-1) / batch_size for p in learner.parameters()], 0)

            grad_prep = preprocess_grad(grad)
            loss_prep = preprocess_loss(loss)
            # Order matches the unpacking in MetaLearner.forward:
            # (loss, grad_p, grad).
            # NOTE(review): `grad` is passed as a flat 1-D tensor; confirm the
            # shape the meta-learner expects.
            ci, _state = metalearner([loss_prep, grad_prep, grad])
    return ci

In [None]:
# Initialize values
# `dat` has to be a dict: it is indexed with string keys everywhere.
dat = {
    'mode': 'train',
    'n_shot': 5,
    'n_eval': 15,
    'n_class': 10,
    'input_size': 4,
    'hidden_size': 20,
    'lr': 1e-3,
    'episode': 10000,
    'epoch': 8,
    'batch_size': 25,
    'image_size': 32,
    'grad_clip': 0.25,
    'momentum': 0.95,
    'eps': 1e-3,
}
# The following keys are read by other cells but were never set.
# NOTE(review): the values are placeholders — adjust to the experiment.
dat.update({
    'seed': 42,
    'save': 'logs',
    'log_freq': 50,
    'val_freq': 1000,
    'episode_val': 100,
    'n_workers': 2,
})

# Load data
# NOTE(review): hard-coded Kaggle path that already ends in `test/`; the
# dataset classes append the phase name ('train', 'val', 'test') to it —
# confirm the directory layout.
dat['data_root'] = '/kaggle/input/cifar10_dat/cifar10/test/'

loss_list = []
# create_logger returns the statistics dict used by log_batch_info.
logs = create_logger(dat)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
random.seed(dat['seed'])
np.random.seed(dat['seed'])
torch.manual_seed(dat['seed'])

# learner & meta learner
# Learner(image_size, eps, momentum, n_classes); the second learner receives
# the parameters produced by the meta-learner.
learner_grad = Learner(dat['image_size'], dat['eps'], dat['momentum'], dat['n_class']).to(device)
learner_no_grad = copy.deepcopy(learner_grad)
# MetaLearner(input_size, hidden_size)
metalearner = MetaLearner(dat['input_size'], dat['hidden_size']).to(device)

optim = torch.optim.Adam(metalearner.parameters(), dat['lr'])

train_loader, validation_loader, test_loader = process_data(dat)

# Adapters for test_meta, which expects attribute-style config access and a
# metrics object with an update(**kwargs) method.
eval_stats = {'eval': {'loss': [], 'acc': []}}
config = SimpleNamespace(device=device, **dat)
metrics = SimpleNamespace(
    update=lambda **kwargs: log_batch_info(eval_stats, kwargs, dat['log_freq']))

best_accuracy = 0.0
logging.info("Begin Training")

# print gpu status and usage
print('Using device:', device)
print()

#Additional Info when using cuda
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')

# Meta Training
# The label tensors only depend on the configuration.
train_target = torch.LongTensor(np.repeat(range(dat['n_class']), dat['n_shot'])).to(device)
test_target = torch.LongTensor(np.repeat(range(dat['n_class']), dat['n_eval'])).to(device)

for episode, (episode_x, _) in enumerate(train_loader):
    if episode >= dat['episode']:
        break

    # Same split as in test_meta: the first n_shot samples of every class are
    # used for training, the remaining ones for evaluation.
    # assumes episode_x is (n_class, n_shot + n_eval, C, H, W) — TODO confirm
    image_shape = episode_x.shape[-3:]
    train_input = episode_x[:, :dat['n_shot']].reshape(-1, *image_shape).to(device)
    test_input = episode_x[:, dat['n_shot']:].reshape(-1, *image_shape).to(device)

    # Training with MetaLearner
    learner_grad.reset_batch_stats()
    learner_no_grad.reset_batch_stats()
    learner_grad.train()
    learner_no_grad.train()
    CI = learn_train(learner_grad, metalearner, train_input, train_target, dat)

    # Validation loss training
    learner_no_grad.transfer_params(learner_grad, CI)
    output = learner_no_grad(test_input)
    loss = learner_no_grad.criterion(output, test_target)
    # Top-1 accuracy in percent.
    acc_level = (output.argmax(dim=1) == test_target).float().mean().item() * 100.0
    optim.zero_grad()
    loss.backward()
    # dat['grad_clip'] was configured but never applied.
    nn.utils.clip_grad_norm_(metalearner.parameters(), dat['grad_clip'])
    optim.step()
    log_batch_info(
        logs,
        {'phase': 'train', 'eps': episode, 'total_eps': dat['episode'],
         'loss': loss.item(), 'acc': acc_level},
        dat['log_freq'])
    loss_list.append(loss.item())

    # Meta Validation
    if episode % dat['val_freq'] == 0 and episode != 0:
        # Checkpoint the meta-learner and optimizer under dat['save'].
        torch.save(
            {'episode': episode,
             'metalearner': metalearner.state_dict(),
             'optim': optim.state_dict()},
            os.path.join(dat['save'], 'meta-learner-{}.pth.tar'.format(episode)))
        # Start every validation run from empty statistics.
        reset_stats(eval_stats)
        val_accuracy = test_meta(episode, validation_loader, learner_grad, learner_no_grad,
                                 metalearner, config, metrics)
        if val_accuracy is not None and val_accuracy > best_accuracy:
            best_accuracy = val_accuracy


torch.save(learner_grad.state_dict(),'learner_grad.pt')
torch.save(metalearner.state_dict(),'metalearner.pt')
np.savetxt('training_loss.csv', loss_list, delimiter=',')
torch.save(optim.state_dict(), 'optimizer.pt')