In [1]:
import os
import glob
import matplotlib.pyplot as plt
import numpy as np
import argparse
import torch
import torch.nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR

from src.dataset import CUB as Dataset
from src.sampler import Sampler
from src.train_sampler import Train_Sampler
from src.utils import count_acc, Averager, csv_write, square_euclidean_metric
from model import FewShotModel

from src.test_dataset import CUB as Test_Dataset
from src.test_sampler import Test_Sampler

import torchvision
from torch.utils.tensorboard import SummaryWriter


" User input value "
TOTAL = 99000  # total step of training
PRINT_FREQ = 10  # frequency of print loss and accuracy at training step
VAL_FREQ = 250  # frequency of model eval on validation dataset
SAVE_FREQ = 250  # frequency of saving model
TEST_SIZE = 200  # fixed

" fixed value "
VAL_TOTAL = 100


In [2]:
def weights_init(m):
    if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Linear):
        torch.nn.init.kaiming_normal_(m.weight)
        torch.nn.init.zeros_(m.bias)
        
def Test_phase(model, args, k):
    model.eval()

    csv = csv_write(args)

    dataset = Test_Dataset(args.dpath)
    test_sampler = Test_Sampler(dataset._labels, n_way=args.nway, k_shot=args.kshot, query=args.query)
    test_loader = DataLoader(dataset=dataset, batch_sampler=test_sampler, num_workers=4, pin_memory=True)

    print('Test start!')
    for i in range(TEST_SIZE):
        for episode in test_loader:
            data = episode.cuda()

            data_shot, data_query = data[:k], data[k:]

            """ TEST Method """
            """ Predict the query images belong to which classes
            
            At the training phase, you measured logits. 
            The logits can be distance or similarity between query images and 5 images of each classes.
            From logits, you can pick a one class that have most low distance or high similarity.
            
            ex) # when logits is distance
                pred = torch.argmin(logits, dim=1)
            
                # when logits is prob
                pred = torch.argmax(logits, dim=1)
                
            pred is torch.tensor with size [20] and the each component value is zero to four
            """
            emb_shot = model(data_shot)
            emb_query = model(data_query)

            kshot = args.kshot
            nway = args.nway
            emb_dim = emb_shot.shape[-1]
            n_query = emb_query.size()[0]

            prototypes = torch.zeros((args.nway, emb_dim), requires_grad=False).cuda()
            for nw in range(nway):
                proto = emb_shot[nw*kshot :(nw+1)*kshot]
                proto = torch.sum(proto, dim=0)/kshot
                prototypes[nw,:] = proto

            def _euc_dist(pro,q):
                a = pro.shape[0]
                logit = pro - q.unsqueeze(0).expand(a,-1)
                logit = (logit**2).mean(1)
                return logit.squeeze()
            
            pred = torch.zeros((n_query), requires_grad=False).cuda()
            for q in range(n_query):
                logit = _euc_dist(prototypes, emb_query[q])
                pred[q] = torch.argmin(logit)
                

            # save your prediction as StudentID_Name.csv file
            csv.add(pred)

    csv.close()
    print('Test finished, check the csv file!')
    exit()

def train(args):
    # tnesorboard writer
    writer = SummaryWriter(args.writer_path)
    # the number of N way, K shot images
    k = args.nway * args.kshot

    # Train data loading
    dataset = Dataset(args.dpath, state='train')
    train_sampler = Train_Sampler(dataset._labels, n_way=args.nway, k_shot=args.kshot, query=args.query)
    data_loader = DataLoader(dataset=dataset, batch_sampler=train_sampler, num_workers=4, pin_memory=True)

    # Validation data loading
    val_dataset = Dataset(args.dpath, state='val')
    val_sampler = Sampler(val_dataset._labels, n_way=args.nway, k_shot=args.kshot, query=args.query)
    val_data_loader = DataLoader(dataset=val_dataset, batch_sampler=val_sampler, num_workers=4, pin_memory=True)

    """ TODO 1.a """
    " Make your own model for Few-shot Classification in 'model.py' file."
    model = FewShotModel()
    model.apply(weights_init)
    model.cuda()
    """ TODO 1.a END """
    
    
    """ TODO 1.b (optional) """
    " Set an optimizer or scheduler for Few-shot classification (optional) "

    # Default optimizer setting
    optimizer = torch.optim.Adam(model.parameters(), lr=0.002)
    scheduler = StepLR(optimizer,step_size=5000,gamma=0.5)
    """ TODO 1.b (optional) END """
    
    epoch_range = range(TOTAL)
    
    # pretrained model load
    if args.restore_ckpt:
        last_ckpt = sorted(glob.glob(args.ckpt_path+'/*.pth'))[-1]
        ckpt = torch.load(last_ckpt)
        
        model.load_state_dict(ckpt['state_dict'])
        model.cuda()
        optimizer.load_state_dict(ckpt['optimizer'])
        scheduler = ckpt['scheduler']
        start_epoch = ckpt['epoch']
        
        epoch_range = range(start_epoch, TOTAL)
    
    print('restore ckpt', args.restore_ckpt)
    print(epoch_range)

    if args.test_mode == 1:
        Test_phase(model, args, k)
    
    model.train()
    tl = Averager()  # save average loss
    ta = Averager()  # save average accuracy
    # training start
    print('train start')
    for i in epoch_range:
        #scheduler.step(i)
        for episode in data_loader:
            optimizer.zero_grad()
            if torch.cuda.is_available():
                data, label = [_.cuda() for _ in episode]  # load an episode
            # split an episode images and labels into shots and query set
            # note! data_shot shape is ( nway * kshot, 3, h, w ) not ( kshot * nway, 3, h, w )
            # Take care when reshape the data shot
            data_shot, data_query = data[:k], data[k:]

            label_shot, label_query = label[:k], label[k:]
            label_shot = sorted(list(set(label_shot.tolist())))

            # convert labels into 0-4 values
            label_query = label_query.tolist()
            labels = []
            for j in range(len(label_query)):
                label = label_shot.index(label_query[j])
                labels.append(label)
            labels = torch.tensor(labels).cuda()
            
            if i==0 :
                img_grid = torchvision.utils.make_grid(data_shot)
                writer.add_image('data_shot', img_grid)
                
                
            """ TODO 2 ( Same as above TODO 2 ) """
            """ Train the model 
            Input:
                data_shot : torch.tensor, shot images, [args.nway * args.kshot, 3, h, w]
                            be careful when using torch.reshape or .view functions
                data_query : torch.tensor, query images, [args.query, 3, h, w]
                labels : torch.tensor, labels of query images, [args.query]
            output:
                loss : torch scalar tensor which used for updating your model
                logits : A value to measure accuracy and loss
            """
            emb_shot = model(data_shot)
            emb_query = model(data_query)

            kshot = args.kshot
            nway = args.nway
            emb_dim = emb_shot.shape[-1]
            n_query = labels.size()[0]

            prototypes = torch.zeros((args.nway, emb_dim), requires_grad=False).cuda()
            for nw in range(nway):
                proto = emb_shot[nw*kshot :(nw+1)*kshot]
                proto = torch.sum(proto, dim=0)/kshot
                prototypes[nw,:] = proto

            def _euc_dist(pro,q):
                a = pro.shape[0]
                logit = pro - q.unsqueeze(0).expand(a,-1)
                logit = (logit**2).mean(1)
                return logit.squeeze()
            
            logits = torch.zeros((n_query, nway), requires_grad=False).cuda()
            for q in range(n_query):
                logit = _euc_dist(prototypes, emb_query[q])
                logits[q,:] = logit

            loss = F.cross_entropy(-logits, labels)
            

            """ TODO 2 END """
            
           
            acc = count_acc(logits, labels)

            tl.add(loss.item())
            ta.add(acc)
            loss.backward()
            optimizer.step()
            #if (i+1) % PRINT_FREQ == 0:
                #print(torch.argmin(logits, dim=1))
            
            proto = None; logits = None; loss = None
            
        
        if (i+1) % PRINT_FREQ == 0:
            writer.add_scalar('train loss', tl.item(), i+1)
            writer.add_scalar('train acc', ta.item(), i+1)
            
            print('train {}, loss={:.4f} acc={:.4f}'.format(i+1, tl.item(), ta.item()))
            # initialize loss and accuracy mean
            tl = None
            ta = None
            tl = Averager()
            ta = Averager()

        # validation start
        if (i+1) % VAL_FREQ == 0:
            print('validation start')
            model.eval()
            with torch.no_grad():
                vl = Averager()  # save average loss
                va = Averager()  # save average accuracy
                for j in range(VAL_TOTAL):
                    for episode in val_data_loader:
                        data, label = [_.cuda() for _ in episode]

                        data_shot, data_query = data[:k], data[k:] # load an episode

                        label_shot, label_query = label[:k], label[k:]
                        label_shot = sorted(list(set(label_shot.tolist())))

                        label_query = label_query.tolist()

                        labels = []
                        for j in range(len(label_query)):
                            label = label_shot.index(label_query[j])
                            labels.append(label)
                        labels = torch.tensor(labels).cuda()

                        """ TODO 2 ( Same as above TODO 2 ) """
                        """ Train the model 
                        Input:
                            data_shot : torch.tensor, shot images, [args.nway * args.kshot, 3, h, w]
                                        be careful when using torch.reshape or .view functions
                            data_query : torch.tensor, query images, [args.query, 3, h, w]
                            labels : torch.tensor, labels of query images, [args.query]
                        output:
                            loss : torch scalar tensor which used for updating your model
                            logits : A value to measure accuracy and loss
                        """
                        emb_shot = model(data_shot)
                        emb_query = model(data_query)

                        kshot = args.kshot
                        nway = args.nway
                        emb_dim = emb_shot.shape[-1]
                        n_query = labels.size()[0]

                        prototypes = torch.zeros((args.nway, emb_dim), requires_grad=True).cuda()
                        for nw in range(nway):
                            proto = emb_shot[nw*kshot :(nw+1)*kshot]
                            proto = torch.sum(proto, dim=0)/kshot
                            prototypes[nw,:] = proto

                        def _euc_dist(pro,q):
                            a = pro.shape[0]
                            logit = pro - q.unsqueeze(0).expand(a,-1)
                            logit = (logit**2).mean(1)
                            return logit.squeeze()

                        logits = torch.zeros((n_query, nway), requires_grad=True).cuda()
                        for q in range(n_query):
                            logit = _euc_dist(prototypes, emb_query[q])
                            logits[q,:] = logit

                        loss = F.cross_entropy(-logits, labels)


                        """ TODO 2 END """

                        acc = count_acc(logits,labels)

                        vl.add(loss.item())
                        va.add(acc)

                        proto = None; logits = None; loss = None
                
                writer.add_scalar('val loss', vl.item(), i+1)
                writer.add_scalar('val acc',  va.item(), i+1)

                print('val accuracy mean : %.4f' % va.item())
                print('val loss mean : %.4f' % vl.item())

                # initialize loss and accuracy mean
                vl = None
                va = None
                vl = Averager()
                va = Averager()
            model.train()

        if (i+1) % SAVE_FREQ == 0:
            PATH =  args.ckpt_path + '/%06d_%s.pth' % (i + 1, args.name)
            state = {'epoch': i , 'state_dict': model.state_dict(),
             'optimizer': optimizer.state_dict(), 'scheduler': scheduler}
            torch.save(state, PATH)
            print('model saved, iteration : %d' % i)


In [3]:
class args():
    def __init__(self):
        self.name = 'model'
        self.dpath = './dataset/CUB_200_2011/CUB_200_2011'
        self.restore_ckpt = True
        self.nway = 5
        self.kshot = 5
        self.query = 20
        self.ntest = 100
        self.gpus = 0
        self.exp_name = 'exp7'
        self.ckpt_path = 'checkpoints/' + self.exp_name 
        self.writer_path = 'runs/' + self.exp_name 
        self.ntest = 200
        self.test_mode = 0
        
    def set_exp_name(self, name) : 
        self.exp_name = name
        self.ckpt_path = 'checkpoints/' + name
        self.writer_path = 'runs/' + name
        
args = args()
    
exp_name = input("Experiment name : ")
args.set_exp_name(exp_name)
is_test = input("test? 1/0 : ")
args.test_mode = int(is_test)
if is_test:
    args.ntest = input("number of test(defalut-200): ")
    args.restore_ckpt = True
else:        
    is_restore = input("restore ? 1/0 : ")
    args.restore_ckpt = int(is_restore)

if not os.path.isdir(args.ckpt_path):
    os.mkdir(args.ckpt_path)
    
torch.cuda.set_device(args.gpus)
train(args)
    

Experiment name : exp8
test? 1/0 : 1
number of test(defalut-200): 200
restore ckpt True
range(74999, 99000)
Test start!


  pos_ss, pos_q = torch.tensor(pos[:self.k_shot], dtype=torch.long), torch.tensor(pos[self.k_shot:], dtype=torch.long)


torch.Size([25, 3, 224, 224])
tensor(3, device='cuda:0')
tensor(1, device='cuda:0')
tensor(4, device='cuda:0')
tensor(1, device='cuda:0')
tensor(0, device='cuda:0')
tensor(2, device='cuda:0')
tensor(4, device='cuda:0')
tensor(2, device='cuda:0')
tensor(3, device='cuda:0')
tensor(0, device='cuda:0')
tensor(2, device='cuda:0')
tensor(4, device='cuda:0')
tensor(3, device='cuda:0')
tensor(3, device='cuda:0')
tensor(0, device='cuda:0')
tensor(2, device='cuda:0')
tensor(2, device='cuda:0')
tensor(0, device='cuda:0')
tensor(1, device='cuda:0')
tensor(2, device='cuda:0')
torch.Size([25, 3, 224, 224])
torch.Size([25, 3, 224, 224])
torch.Size([25, 3, 224, 224])
torch.Size([25, 3, 224, 224])
torch.Size([25, 3, 224, 224])
torch.Size([25, 3, 224, 224])
torch.Size([25, 3, 224, 224])
torch.Size([25, 3, 224, 224])
torch.Size([25, 3, 224, 224])
torch.Size([25, 3, 224, 224])
torch.Size([25, 3, 224, 224])
torch.Size([25, 3, 224, 224])
torch.Size([25, 3, 224, 224])
torch.Size([25, 3, 224, 224])
torch.Size

KeyboardInterrupt: 

## train dataset mean, std

In [None]:
dataset = Dataset(args.dpath, state='train')
loader = DataLoader(
    dataset,
    batch_size=10,
    num_workers=1,
    shuffle=False
)

mean_t = 0.
std_t = 0.
nb_samples_t = 0.

for data in loader:
    batch_samples = data[0].size(0)
    data = data[0].view(batch_samples, data[0].size(1), -1)
    mean_t += data.mean(2).sum(0)
    std_t += data.std(2).sum(0)
    nb_samples_t += batch_samples

mean_t /= nb_samples_t
std_t /= nb_samples_t
print(mean_t, std_t, nb_samples_t)

## val dataset mean, std

In [None]:
dataset = Dataset(args.dpath, state='val')
loader = DataLoader(
    dataset,
    batch_size=10,
    num_workers=1,
    shuffle=False
)

mean_v = 0.
std_v = 0.
nb_samples_v= 0.

for data in loader:
    batch_samples = data[0].size(0)
    data = data[0].view(batch_samples, data[0].size(1), -1)
    mean_v += data.mean(2).sum(0)
    std_v += data.std(2).sum(0)
    nb_samples_v += batch_samples

mean_v /= nb_samples_v
std_v /= nb_samples_v
print(mean_v, std_v, nb_samples_v)

In [None]:
print(((mean_v * nb_samples_v) + (mean_t * nb_samples_t))/(nb_samples_t+nb_samples_v),
((std_v * nb_samples_v) + (std_t * nb_samples_t))/(nb_samples_t+nb_samples_v))