In [None]:
import numpy as np
import torch
import torchvision
import matplotlib.pyplot as plt
from time import time
from torchvision import datasets, transforms
import torchvision.models as models
from torch import nn, optim
from torch.nn.modules.loss import *
from Loss.triplet import *
from session import *
from LR_Schedule.cyclical import Cyclical
from LR_Schedule.cos_anneal import CosAnneal
from LR_Schedule.lr_find import lr_find
from callbacks import *
from validation import *
from validation import _AccuracyMeter
import Datasets.ImageData as ImageData
from Transforms.ImageTransforms import *
import util
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.decomposition import PCA
from torch.utils.tensorboard import SummaryWriter
%matplotlib notebook

In [None]:
# Load IPython's autoreload extension in mode 2 so edits to imported project
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# TensorBoard SummaryWriter created with default arguments; it is used further
# down to log the validation embeddings via add_embedding.
writer = SummaryWriter()

In [None]:
# Check for CUDA *before* selecting a device so a machine without a GPU fails
# with a clear message instead of an opaque error from set_device.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available; this notebook expects a GPU at device index 0.")

# Make device 0 the current CUDA device for everything that follows.
torch.cuda.set_device(0)
# Turn on cuDNN benchmark mode (see the PyTorch cuDNN backend notes).
torch.backends.cudnn.benchmark = True

# Display the name of the selected GPU as the cell output.
torch.cuda.get_device_name(0)

In [None]:
# Normalisation shared by the training and validation pipelines.
# NOTE(review): the constants are presumably the per-channel CIFAR-10 mean/std - confirm.
normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))

# Training pipeline: light random augmentation, then tensor conversion and normalisation.
transform = transforms.Compose([
                                transforms.RandomRotation(9),
                                # NOTE(review): the upper bound of `scale` exceeds 1.0; torchvision documents
                                # `scale` as a fraction of the original image area - confirm this is intended.
                                transforms.RandomResizedCrop(32, scale=(.95, 1.05)),
                                transforms.RandomHorizontalFlip(),
                                transforms.ToTensor(),
                                normalize])

# Validation pipeline: no random augmentation, so the embeddings computed on the
# validation set are deterministic and describe the unmodified images.
val_transform = transforms.Compose([
                                    transforms.ToTensor(),
                                    normalize])

trainset = torchvision.datasets.CIFAR10(root='/media/drake/MX500/Datasets/cifar-10/train', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=16,
                                          shuffle=True, num_workers=2)

valset = torchvision.datasets.CIFAR10(root='/media/drake/MX500/Datasets/cifar-10/test', train=False,
                                       download=True, transform=val_transform)
# shuffle=False keeps batches in dataset order, so outputs collected from this
# loader line up index-for-index with valset.targets.
valloader = torch.utils.data.DataLoader(valset, batch_size=16,
                                         shuffle=False, num_workers=2)

In [None]:
class Flatten(nn.Module):
    """Module that collapses every dimension after the first into a single one."""

    def forward(self, input):
        """Return ``input`` viewed as a 2-D tensor of shape ``(input.size(0), -1)``."""
        leading = input.shape[0]
        return input.view(leading, -1)

In [None]:
# ResNet-18 built without pretrained weights; its final fully connected layer is
# swapped for a dropout + linear head that outputs a 64-dimensional vector.
model = models.resnet18(pretrained=False)
num_ftrs = model.fc.in_features
embedding_head = [nn.Dropout(p=0.2), nn.Linear(num_ftrs, 64)]
model.fc = nn.Sequential(*embedding_head)

# Show the resulting architecture as the cell output.
model

In [None]:
# Loss object and optimizer handed to Session in the next cell.
# NOTE(review): the argument 1 is presumably the triplet margin - TODO confirm
# against BatchAllTripletLoss.
criterion = BatchAllTripletLoss(1)
# The optimizer *class* (not an instance) is stored here and passed on as-is.
optim_fn = optim.AdamW

In [None]:
# Bundle the model, loss and optimizer class into a training Session.
# NOTE(review): 1e-3 is presumably the initial learning rate - TODO confirm; a
# later cell calls sess.set_lr(1e-4).
sess = Session(model, criterion, optim_fn, 1e-3)

In [None]:
# Run the project's learning-rate finder on the training loader, starting at 1e-7.
# NOTE(review): verify whether lr_find restores the model/optimizer state it
# touches before real training starts.
lr_find(sess, trainloader, start_lr=1e-7)

In [None]:
# Set the session learning rate to 1e-4 (presumably chosen from the lr_find
# run above - TODO confirm).
sess.set_lr(1e-4)

In [None]:
# Training schedule over the training loader with two callbacks: a loss logger
# and a cosine-annealing LR schedule.
# NOTE(review): len(trainloader) is the number of batches per epoch, presumably
# the length of the first annealing cycle, with T_mult=2 presumably doubling
# each following cycle - TODO confirm against CosAnneal.
schedule = TrainingSchedule(trainloader, [LossLogger(), CosAnneal(len(trainloader), T_mult=2)])

In [None]:
# Train with the schedule above.
# NOTE(review): 63 is presumably the epoch count; 63 = 1 + 2 + 4 + 8 + 16 + 32,
# which would end training exactly at the close of the sixth doubling cosine
# cycle - TODO confirm.
sess.train(schedule, 63)

In [None]:
# Embed the whole validation set, collecting one CPU tensor per batch.
# EvalModel is the project's context manager (presumably switches the model to
# eval mode for the duration of the block - TODO confirm).
# torch.no_grad() stops autograd from recording a graph for every batch, which
# inference does not need and which only costs memory.
with EvalModel(model), torch.no_grad():
    outputs = []
    # Loop names avoid shadowing the builtin `input`; the labels are unused here.
    for images, _labels in valloader:
        # Call the module itself (not .forward) so registered hooks still run.
        batch_embeddings = model(util.to_gpu(images))
        outputs.append(batch_embeddings.detach().cpu())
     

In [None]:
# Join the per-batch output tensors along the first (batch) axis into one
# tensor, then display its shape.
preds = torch.cat(outputs, dim=0)
preds.shape

In [None]:
# Number of validation samples. torchvision's CIFAR10 keeps `data` as a NumPy
# array, where `.size` is an integer attribute rather than a method, so
# `.size(0)` raises TypeError; `.shape[0]` works for arrays and tensors alike.
valset.data.shape[0]

In [None]:
# Log the embedding matrix to TensorBoard, labelling each row with the class
# index stored at the same position in valset.targets.
writer.add_embedding(mat=preds, metadata=valset.targets)

In [None]:
# Project the embeddings onto their first three principal components for plotting.
# random_state pins the result whenever scikit-learn picks a randomized SVD
# solver, so re-running the notebook gives the same projection.
pca = PCA(n_components=3, random_state=0)
# Hand scikit-learn an explicit NumPy array instead of relying on implicit
# tensor-to-array conversion (preds is a CPU tensor without gradients).
reduced = pca.fit_transform(preds.numpy())

In [None]:
# Preview the first few validation labels rather than dumping the whole list
# into the cell output. The former `dir(valset)` call was leftover debugging
# whose result was discarded, so it is removed.
valset.targets[:20]

In [None]:
# 3-D scatter of the PCA-reduced embeddings, coloured by class index.
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
# A qualitative colormap keeps the classes visually distinct; a sequential one
# would suggest an ordering between class indices that does not exist.
points = ax.scatter(reduced[:, 0], reduced[:, 1], reduced[:, 2],
                    c=valset.targets, cmap='tab10')
# Title and axis labels so the figure can be read on its own.
ax.set_title('Validation embeddings (first three principal components)')
ax.set_xlabel('PC 1')
ax.set_ylabel('PC 2')
ax.set_zlabel('PC 3')
# The colour bar acts as the key from colour to class index; the trailing `;`
# suppresses the repr in the cell output.
fig.colorbar(points, ax=ax, label='class index');