In [2]:
import os
import glob
import argparse
import torch
import torch.nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR

from src.dataset import CUB as Dataset
from src.sampler import Sampler
from src.train_sampler import Train_Sampler
from src.utils import count_acc, Averager
from model import FewShotModel

import torchvision
from torch.utils.tensorboard import SummaryWriter

import matplotlib.pyplot as plt
import numpy as np

# ---- user-tunable values ----
TOTAL = 99000     # number of outer training iterations (epochs) run by train()
PRINT_FREQ = 10   # log training loss / accuracy every PRINT_FREQ iterations
VAL_FREQ = 250    # evaluate on the validation split every VAL_FREQ iterations
SAVE_FREQ = 250   # write a checkpoint every SAVE_FREQ iterations
TEST_SIZE = 200   # fixed

# ---- fixed values ----
VAL_TOTAL = 100   # passes over the validation loader per evaluation


In [None]:
def weights_init(m):
    """Initialise one module's parameters; meant to be passed to ``model.apply``.

    Modules are matched by substring of their class name, so ``Conv2d``,
    ``BatchNorm2d``, ``Linear`` (and any class whose name contains those words)
    are handled:

    * ``Conv``      -> Kaiming-normal weight, zero bias
    * ``BatchNorm`` -> weight ~ N(0, 1), zero bias
    * ``Linear``    -> weight ~ N(0, 0.01), bias filled with ones

    Parameters that are missing or ``None`` -- ``bias=False`` layers,
    ``affine=False`` norm layers, or container modules that merely share one of
    these words in their class name -- are skipped instead of raising.
    """
    classname = m.__class__.__name__
    # A module matched only by name may not own these parameters at all.
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if 'Conv' in classname:
        if weight is not None:
            torch.nn.init.kaiming_normal_(weight)
        if bias is not None:
            torch.nn.init.zeros_(bias)
    elif 'BatchNorm' in classname:
        if weight is not None:
            torch.nn.init.normal_(weight)
        if bias is not None:
            torch.nn.init.zeros_(bias)
    elif 'Linear' in classname:
        if weight is not None:
            weight.data.normal_(0, 0.01)
        if bias is not None:
            torch.nn.init.ones_(bias)
        


def _episode_query_labels(label, k):
    """Remap an episode's query labels to positions in its sorted support classes.

    Args:
        label: 1-D tensor of dataset class ids for one episode; the first ``k``
            entries label the support (shot) images, the rest the query images.
        k: number of support images, ``nway * kshot``.

    Returns:
        (class_ids, targets): ``class_ids`` is the sorted list of distinct support
        class ids; ``targets`` is an int64 tensor on ``label``'s device giving,
        for each query image, the index of its class in ``class_ids``.

    Raises:
        KeyError: if a query image's class does not occur in the support set.
    """
    class_ids = sorted(set(label[:k].tolist()))
    position = {class_id: idx for idx, class_id in enumerate(class_ids)}
    targets = torch.tensor([position[c] for c in label[k:].tolist()],
                           dtype=torch.long, device=label.device)
    return class_ids, targets


def _prototype_logits(model, data_shot, data_query, shot_labels, class_ids):
    """Embed one episode and score every query image against the class prototypes.

    Prototype ``j`` is the mean embedding of the support images whose label is
    ``class_ids[j]``. Selecting rows by label (instead of slicing fixed row
    ranges) keeps prototype ``j`` aligned with query target ``j`` whatever order
    the sampler emits the support classes in.

    Returns:
        (n_query, nway) tensor of *negative* squared Euclidean distances averaged
        over the embedding dimension: larger means closer, which is the
        convention ``F.cross_entropy`` expects for its input.
    """
    emb_shot = model(data_shot)
    emb_query = model(data_query)
    prototypes = torch.stack([emb_shot[shot_labels == c].mean(dim=0) for c in class_ids])
    # (n_query, 1, d) - (1, nway, d) -> (n_query, nway, d); averaging over d
    # (rather than summing) keeps the scale used by the original training code.
    sq_dist = ((emb_query.unsqueeze(1) - prototypes.unsqueeze(0)) ** 2).mean(dim=2)
    return -sq_dist


def _evaluate(model, val_data_loader, k):
    """Mean query loss and accuracy of ``model`` over VAL_TOTAL passes of the val loader.

    Puts the model in eval mode and leaves it there; the caller restores train mode.
    Returns ``(mean_loss, mean_acc)`` as produced by ``Averager.item()``.
    """
    model.eval()
    vl = Averager()  # average validation loss
    va = Averager()  # average validation accuracy
    with torch.no_grad():
        for _ in range(VAL_TOTAL):
            for data, label in val_data_loader:
                data, label = data.cuda(), label.cuda()
                class_ids, labels = _episode_query_labels(label, k)
                logits = _prototype_logits(model, data[:k], data[k:], label[:k], class_ids)
                vl.add(F.cross_entropy(logits, labels).item())
                va.add(count_acc(logits, labels))
    return vl.item(), va.item()


def train(args):
    """Train a prototypical few-shot classifier and checkpoint it periodically.

    Reads from ``args``: dpath, nway, kshot, query, restore_ckpt, ckpt_path,
    writer_path, name. Runs epochs ``start_epoch .. TOTAL-1``, each a full pass
    over the episodic training loader. Every PRINT_FREQ / VAL_FREQ / SAVE_FREQ
    epochs it logs training metrics, evaluates on the validation split, and
    writes a checkpoint into ``args.ckpt_path``.

    When ``args.restore_ckpt`` is truthy, the lexicographically last ``*.pth`` in
    ``args.ckpt_path`` is loaded and training resumes with the epoch after the
    one stored in it; if there is no checkpoint, training starts from scratch.
    """
    # tensorboard writer
    writer = SummaryWriter(args.writer_path)

    # the number of N way, K shot images
    k = args.nway * args.kshot

    # Train data loading
    dataset = Dataset(args.dpath, state='train')
    train_sampler = Train_Sampler(dataset._labels, n_way=args.nway, k_shot=args.kshot, query=args.query)
    data_loader = DataLoader(dataset=dataset, batch_sampler=train_sampler, num_workers=4, pin_memory=True)

    # Validation data loading
    val_dataset = Dataset(args.dpath, state='val')
    val_sampler = Sampler(val_dataset._labels, n_way=args.nway, k_shot=args.kshot, query=args.query)
    val_data_loader = DataLoader(dataset=val_dataset, batch_sampler=val_sampler, num_workers=4, pin_memory=True)

    """ TODO 1.a """
    " Make your own model for Few-shot Classification in 'model.py' file."
    model = FewShotModel()
    model.apply(weights_init)
    model.cuda()
    """ TODO 1.a END """

    """ TODO 1.b (optional) """
    " Set an optimizer or scheduler for Few-shot classification (optional) "
    optimizer = torch.optim.Adam(model.parameters(), lr=0.002)
    # NOTE(review): scheduler.step() is never called, so the learning rate stays
    # at its initial value; the scheduler is only carried through checkpoints.
    scheduler = StepLR(optimizer, step_size=5000, gamma=0.5)
    """ TODO 1.b (optional) END """

    # Default to a fresh run; overwritten below when a checkpoint is restored.
    start_epoch = 0
    if args.restore_ckpt:
        # File names are zero-padded ('%06d_...'), so lexicographic order is epoch order.
        ckpt_files = sorted(glob.glob(args.ckpt_path + '/*.pth'))
        if ckpt_files:
            ckpt = torch.load(ckpt_files[-1])
            model.load_state_dict(ckpt['state_dict'])
            optimizer.load_state_dict(ckpt['optimizer'])
            sched_state = ckpt['scheduler']
            if not isinstance(sched_state, dict):
                # older checkpoints pickled the scheduler object itself
                sched_state = sched_state.state_dict()
            # load into the live scheduler so it stays bound to `optimizer`
            scheduler.load_state_dict(sched_state)
            # 'epoch' stores the index of the last completed epoch; resume after it
            start_epoch = ckpt['epoch'] + 1
        else:
            print('no checkpoint found in', args.ckpt_path, '- training from scratch')

    print('restore ckpt', args.restore_ckpt)
    print('start epoch: ', start_epoch)

    model.train()
    tl = Averager()  # save average loss
    ta = Averager()  # save average accuracy
    print('train start')
    try:
        for i in range(start_epoch, TOTAL):
            for episode_idx, (data, label) in enumerate(data_loader):
                optimizer.zero_grad()
                data, label = data.cuda(), label.cuda()  # load an episode
                # support images come first: data[:k] is ( nway * kshot, 3, h, w )
                data_shot, data_query = data[:k], data[k:]
                # query labels -> indices 0..nway-1 into the sorted support classes
                class_ids, labels = _episode_query_labels(label, k)

                if i == 0 and episode_idx == 0:
                    # log one example support set instead of one per episode
                    writer.add_image('data_shot', torchvision.utils.make_grid(data_shot))

                """ TODO 2 """
                logits = _prototype_logits(model, data_shot, data_query, label[:k], class_ids)
                loss = F.cross_entropy(logits, labels)
                """ TODO 2 END """

                # assumes count_acc takes an argmax over logits (larger = more
                # likely), matching F.cross_entropy -- TODO confirm in src/utils
                acc = count_acc(logits, labels)

                tl.add(loss.item())
                ta.add(acc)
                loss.backward()
                optimizer.step()

            if (i + 1) % PRINT_FREQ == 0:
                writer.add_scalar('train loss', tl.item(), i + 1)
                writer.add_scalar('train acc', ta.item(), i + 1)
                print('train {}, loss={:.4f} acc={:.4f}'.format(i + 1, tl.item(), ta.item()))
                # start a fresh running average for the next PRINT_FREQ epochs
                tl = Averager()
                ta = Averager()

            if (i + 1) % VAL_FREQ == 0:
                print('validation start')
                val_loss, val_acc = _evaluate(model, val_data_loader, k)
                writer.add_scalar('val loss', val_loss, i + 1)
                writer.add_scalar('val acc', val_acc, i + 1)
                print('val accuracy mean : %.4f' % val_acc)
                print('val loss mean : %.4f' % val_loss)
                model.train()

            if (i + 1) % SAVE_FREQ == 0:
                path = args.ckpt_path + '/%06d_%s.pth' % (i + 1, args.name)
                # Store the scheduler's state_dict, not the object, so it can be
                # loaded into the live scheduler on restore.
                state = {'epoch': i, 'state_dict': model.state_dict(),
                         'optimizer': optimizer.state_dict(),
                         'scheduler': scheduler.state_dict()}
                torch.save(state, path)
                print('model saved, iteration : %d' % (i + 1))
    finally:
        writer.close()


In [None]:
class args:
    """Notebook stand-in for the argparse namespace consumed by train().

    The experiment-dependent directories (``ckpt_path`` and ``writer_path``) are
    derived from ``exp_name``; use ``set_exp_name`` to change all three at once.
    """

    def __init__(self):
        # experiment / data
        self.name = 'model'
        self.dpath = './dataset/CUB_200_2011/CUB_200_2011'
        self.restore_ckpt = True
        # episode shape
        self.nway = 5
        self.kshot = 5
        self.query = 20
        self.ntest = 100
        # device index
        self.gpus = 0
        # default experiment and its derived paths
        self.set_exp_name('exp7')

    def set_exp_name(self, name):
        """Set the experiment name and re-derive its checkpoint and TensorBoard dirs."""
        self.exp_name = name
        self.ckpt_path = 'checkpoints/' + name
        self.writer_path = 'runs/' + name
        
# Note: this rebinds the name `args` from the class to an instance.
args = args()

# Interactive experiment setup; surrounding whitespace in the name would end up
# in the directory names, and an empty answer keeps the default experiment.
exp_name = input("Experiment name : ").strip()
args.set_exp_name(exp_name or args.exp_name)
is_restore = input("restore ? 1/0 : ")
args.restore_ckpt = int(is_restore)

# makedirs also creates the parent 'checkpoints/' dir and tolerates re-runs.
os.makedirs(args.ckpt_path, exist_ok=True)

torch.cuda.set_device(args.gpus)
train(args)

In [None]:
getattr(args, 'ckpt_path')  # display where this experiment's checkpoints are written

## Training split: per-channel mean and std

In [14]:
# Scratch check of NumPy's uniform sampler. Multiplying the float by an empty
# tuple raised TypeError, so only the draw is kept.
random_draw = np.random.rand()
random_draw

0.3882586399353469

In [None]:
# Per-channel mean and std of the training split. The std reported here is the
# average of per-image standard deviations, not the std pooled over all pixels.
dataset = Dataset(args.dpath, state='train')
loader = DataLoader(dataset, batch_size=10, num_workers=1, shuffle=False)

mean_t, std_t, nb_samples_t = 0., 0., 0.
for batch in loader:
    images = batch[0]
    n_images = images.size(0)
    # flatten everything after the channel axis -> (n_images, channels, pixels);
    # assumes batch[0] is an image batch laid out as (N, C, H, W) -- confirm in src.dataset
    pixels = images.view(n_images, images.size(1), -1)
    mean_t += pixels.mean(2).sum(0)
    std_t += pixels.std(2).sum(0)
    nb_samples_t += n_images

mean_t /= nb_samples_t
std_t /= nb_samples_t
print(mean_t, std_t, nb_samples_t)

## Validation split: per-channel mean and std

In [None]:
# Per-channel mean and std of the validation split, computed the same way as for
# the training split (std = average of per-image standard deviations).
dataset = Dataset(args.dpath, state='val')
loader = DataLoader(dataset, batch_size=10, num_workers=1, shuffle=False)

mean_v, std_v, nb_samples_v = 0., 0., 0.
for batch in loader:
    images = batch[0]
    n_images = images.size(0)
    # flatten everything after the channel axis -> (n_images, channels, pixels);
    # assumes batch[0] is an image batch laid out as (N, C, H, W) -- confirm in src.dataset
    pixels = images.view(n_images, images.size(1), -1)
    mean_v += pixels.mean(2).sum(0)
    std_v += pixels.std(2).sum(0)
    nb_samples_v += n_images

mean_v /= nb_samples_v
std_v /= nb_samples_v
print(mean_v, std_v, nb_samples_v)

In [None]:
# Sample-count-weighted combination of the train and validation statistics.
# The stds are averaged the same way, which only approximates the combined std.
n_total = nb_samples_t + nb_samples_v
combined_mean = (mean_t * nb_samples_t + mean_v * nb_samples_v) / n_total
combined_std = (std_t * nb_samples_t + std_v * nb_samples_v) / n_total
print(combined_mean, combined_std)

In [None]:
def loss_fn(emb_shot, emb_query, labels, nway=None, kshot=None):
    """Prototypical-network loss with squared Euclidean distance.

    Args:
        emb_shot: 2-D support embeddings of shape (nway * kshot, d), grouped by
            class: rows ``[i*kshot:(i+1)*kshot]`` belong to class ``i``.
        emb_query: 2-D query embeddings of shape (n_query, d).
        labels: int64 tensor of shape (n_query,) with targets in ``[0, nway)``,
            on the same device as the embeddings.
        nway, kshot: episode shape; default to the global ``args.nway`` and
            ``args.kshot``.

    Returns:
        (loss, logits): the scalar cross-entropy loss and the (n_query, nway)
        logits. The logits are *negative* squared distances (larger = closer), so
        ``argmax`` selects the nearest prototype and they can be fed to
        ``F.cross_entropy`` directly.
    """
    if nway is None:
        nway = args.nway
    if kshot is None:
        kshot = args.kshot

    # Row-major grouping by class -> (nway, kshot, d); average the shots per class.
    prototypes = emb_shot[:nway * kshot].reshape(nway, kshot, -1).mean(dim=1)
    # (n_query, 1, d) - (1, nway, d) -> (n_query, nway) squared distances.
    sq_dist = ((emb_query.unsqueeze(1) - prototypes.unsqueeze(0)) ** 2).sum(dim=2)
    logits = -sq_dist
    loss = F.cross_entropy(logits, labels)
    return loss, logits

In [None]:
def loss_fn(emb_shot, emb_query, labels, nway=None, kshot=None):
    """Prototypical-network loss; this definition shadows the one in the cell above.

    Args:
        emb_shot: 2-D support embeddings (nway * kshot, d); rows
            ``[i*kshot:(i+1)*kshot]`` belong to class ``i``.
        emb_query: 2-D query embeddings (n_query, d).
        labels: int64 tensor (n_query,) of targets in ``[0, nway)``, on the same
            device as the embeddings.
        nway, kshot: episode shape; default to the global ``args.nway`` and
            ``args.kshot``.

    Returns:
        (loss, logits): a differentiable scalar cross-entropy loss (so the caller
        can call ``loss.item()`` and ``loss.backward()``) and (n_query, nway)
        logits equal to the negative squared Euclidean distances to the
        prototypes (larger = closer).
    """
    nway = args.nway if nway is None else nway
    kshot = args.kshot if kshot is None else kshot

    # Stay inside autograd: building prototypes through .tolist()/torch.tensor
    # would cut the gradient path back to the encoder.
    prototypes = torch.stack([emb_shot[i * kshot:(i + 1) * kshot].mean(dim=0)
                              for i in range(nway)])
    diff = emb_query[:, None, :] - prototypes[None, :, :]
    logits = -(diff * diff).sum(dim=-1)
    # cross_entropy takes class indices directly; no one-hot encoding needed.
    loss = F.cross_entropy(logits, labels)
    return loss, logits

In [None]:
# Earlier, stand-alone version of the training loop in train().
# NOTE(review): `model`, `optimizer`, `data_loader`, `val_data_loader` and `k`
# are not defined at notebook top level (they are locals of train()), so this
# cell only runs if they have been created in the kernel some other way.
tl = Averager()  # save average loss
ta = Averager()  # save average accuracy
# training start
print('train start')
for i in range(TOTAL):
    for episode in data_loader:
        optimizer.zero_grad()
        data, label = [t.cuda() for t in episode]  # load an episode
        # support images come first, grouped by class: ( nway * kshot, 3, h, w )
        data_shot, data_query = data[:k], data[k:]

        # convert query labels into 0..nway-1 indices of the sorted support classes
        class_ids = sorted(set(label[:k].tolist()))
        labels = torch.tensor([class_ids.index(c) for c in label[k:].tolist()]).cuda()

        """ TODO 2 ( Same as above TODO 2 ) """
        emb_shot = model(data_shot)
        emb_query = model(data_query)
        loss, logit = loss_fn(emb_shot, emb_query, labels)
        """ TODO 2 END """
        acc = count_acc(logit, labels)

        tl.add(loss.item())
        ta.add(acc)

        loss.backward()
        optimizer.step()

    if (i + 1) % PRINT_FREQ == 0:
        print('train {}, loss={:.4f} acc={:.4f}'.format(i + 1, tl.item(), ta.item()))
        # initialize loss and accuracy mean
        tl = Averager()
        ta = Averager()

    # validation start
    if (i + 1) % VAL_FREQ == 0:
        print('validation start')
        model.eval()
        with torch.no_grad():
            vl = Averager()  # save average loss
            va = Averager()  # save average accuracy
            for _ in range(VAL_TOTAL):
                for episode in val_data_loader:
                    data, label = [t.cuda() for t in episode]  # load an episode
                    data_shot, data_query = data[:k], data[k:]

                    class_ids = sorted(set(label[:k].tolist()))
                    labels = torch.tensor([class_ids.index(c) for c in label[k:].tolist()]).cuda()

                    """ TODO 2 ( Same as above TODO 2 ) """
                    # loss_fn expects embeddings, not raw images
                    emb_shot = model(data_shot)
                    emb_query = model(data_query)
                    loss, logit = loss_fn(emb_shot, emb_query, labels)
                    """ TODO 2 END """

                    vl.add(loss.item())
                    va.add(count_acc(logit, labels))

            print('val accuracy mean : %.4f' % va.item())
            print('val loss mean : %.4f' % vl.item())
        model.train()

    if (i + 1) % SAVE_FREQ == 0:
        PATH = 'checkpoints/%d_%s.pth' % (i + 1, args.name)
        torch.save(model.state_dict(), PATH)
        print('model saved, iteration : %d' % (i + 1))


In [None]:
torch.cat((torch.tensor([1, 2, 3, 4, 5]),))  # concatenating a single tensor returns a copy of it

In [None]:
torch.stack([torch.tensor([1, 2, 3, 4, 5]), torch.tensor([1, 2, 3, 4, 6])])  # 2x5 int64 tensor

# Command-line interface kept from the script version of this notebook.
parser = argparse.ArgumentParser()
parser.add_argument('--name', default='model', help="name your experiment")
parser.add_argument('--dpath', '--d', default='./dataset/CUB_200_2011/CUB_200_2011', type=str,
                    help='the path where dataset is located')
# int instead of str: train() only tests truthiness, and any non-empty string
# (including '0') would be truthy and trigger a restore.
parser.add_argument('--restore_ckpt', type=int, default=0,
                    help="1 to resume from the latest checkpoint in ckpt_path, 0 to start fresh")
parser.add_argument('--nway', '--n', default=5, type=int, help='number of class in the support set (5 or 20)')
parser.add_argument('--kshot', '--k', default=5, type=int,
                    help='number of data in each class in the support set (1 or 5)')
parser.add_argument('--query', '--q', default=20, type=int, help='number of query data')
parser.add_argument('--ntest', default=100, type=int, help='number of tests')
# NOTE(review): with nargs='+' this is a list when given on the command line but
# the int 1 by default; torch.cuda.set_device expects a single device.
parser.add_argument('--gpus', type=int, nargs='+', default=1)
parser.add_argument('--exp_name', default='exp7',
                    help='experiment name; determines ckpt_path and writer_path')
# parse_known_args ignores extra arguments such as the '-f <kernel.json>' that
# Jupyter puts on sys.argv, on which parse_args() would exit with an error.
args, _unknown_args = parser.parse_known_args()
# train() reads these two paths; derive them as the notebook's args class does.
args.ckpt_path = 'checkpoints/' + args.exp_name
args.writer_path = 'runs/' + args.exp_name

In [None]:
# Stray scratch cell: it only evaluated the undefined name `ddd`, which raised
# NameError and stopped Restart & Run All. Intentionally a no-op now.
pass