In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import models, datasets
import gensim.downloader
import numpy as np
import sys

In [3]:
# Notebook-wide handle to the word-embedding model.
# Starts out as None; change_target_to_word_vectors() assigns the object
# returned by gensim.downloader.load(), and the evaluation cells read it to
# map vectors back to their nearest words.

word_vectors = None

In [None]:
def load_data_cifar10(train=True, train_on_embeddings=False,
                      root='/nethome/bdevnani3/raid/data',
                      batch_size=4, shuffle=True, num_workers=2):
    """Build a DataLoader over one split of CIFAR-10.

    Every image is resized to 256, center-cropped to 224, converted to a
    tensor and normalized with hard-coded per-channel mean/std values.

    Args:
        train: Forwarded to ``datasets.CIFAR10``; True selects the training
            split, False the test split.
        train_on_embeddings: When True, the dataset's targets are replaced
            by word vectors via ``change_target_to_word_vectors``.
        root: Directory the dataset is read from (and downloaded to when it
            is missing). Defaults to the path the notebook originally used.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader reshuffles the data every epoch.
        num_workers: Number of worker subprocesses used for loading.

    Returns:
        A ``torch.utils.data.DataLoader`` over the prepared dataset.
    """
    # TODO: Automate calculation given a dataset
    normalize = transforms.Normalize(
        mean=[0.491, 0.482, 0.446],
        std=[0.247, 0.243, 0.261],
    )
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    dataset = datasets.CIFAR10(root=root, train=train,
                               download=True, transform=transform)
    if train_on_embeddings:
        dataset = change_target_to_word_vectors(dataset)

    return torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                       shuffle=shuffle, num_workers=num_workers)

def change_target_to_word_vectors(dataset):
    """Replace a dataset's integer class targets with word vectors.

    Each target index is translated to its class name through the inverse of
    ``dataset.class_to_idx`` and then looked up in the global ``word_vectors``
    table. The embedding model ('word2vec-google-news-300') is loaded through
    ``gensim.downloader`` only if ``word_vectors`` has not been populated yet,
    so building several loaders does not reload the model each time.

    Args:
        dataset: Object exposing ``class_to_idx`` (class name -> index) and
            ``targets`` (iterable of class indices). It is modified in place.

    Returns:
        The same ``dataset`` object, with ``targets`` now a NumPy array that
        holds one word vector per sample.
    """
    model = 'word2vec-google-news-300'
    global word_vectors
    # globals().get() also covers the case where the cell that initialises
    # `word_vectors = None` has not been run in this kernel.
    if globals().get('word_vectors') is None:
        word_vectors = gensim.downloader.load(model)

    idx_to_class = {idx: name for name, idx in dataset.class_to_idx.items()}
    dataset.targets = np.array(
        [word_vectors[idx_to_class[target]] for target in dataset.targets]
    )
    return dataset

# Both splits use word-vector targets instead of integer class indices.
trainloader = load_data_cifar10(train=True, train_on_embeddings=True)
testloader = load_data_cifar10(train=False, train_on_embeddings=True)

In [5]:
def set_up_model(out_features=10, loss=None):
    """Create a ResNet-18 with a fresh output layer plus its training objects.

    Args:
        out_features: Width of the replacement final fully-connected layer.
        loss: Loss module to use as the criterion. When None (the default) a
            new ``nn.CrossEntropyLoss`` is created for this call, so no loss
            instance is shared between calls.

    Returns:
        Tuple ``(model, criterion, optimizer, scheduler)`` where the optimizer
        is SGD with momentum and weight decay over all model parameters, and
        the scheduler is cosine annealing with ``T_max=200``.
    """
    # Called without arguments, resnet18 returns randomly initialised weights
    # on both the legacy (`pretrained=False`) and the current (`weights=None`)
    # torchvision APIs, and avoids the deprecation warning for `pretrained`.
    model = models.resnet18()
    # Read the input width from the existing head instead of hard-coding 512.
    model.fc = nn.Linear(in_features=model.fc.in_features, out_features=out_features)
    criterion = nn.CrossEntropyLoss() if loss is None else loss
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4)
    # and a learning rate scheduler
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
    return model, criterion, optimizer, scheduler

# 300 outputs, presumably matching the dimensionality of the
# 'word2vec-google-news-300' target vectors (confirm against the embeddings);
# MSE is used because the targets are vectors, not class indices.
model, criterion, optimizer, scheduler = set_up_model(
    out_features=300,
    loss=nn.MSELoss(),
)

In [None]:
# Evaluate the current model on the TRAINING split: mean per-batch loss plus
# nearest-word accuracy (a prediction counts as correct when the word closest
# to the output vector equals the word closest to the target vector).
# NOTE(review): this cell originally iterated `testloader` while reporting
# "Train Loss"; it now uses `trainloader` to match its label.
train_loss = 0.0
total = 0
correct = 0

# Device selection and the model transfer are loop-invariant, so do them once.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
# Evaluation mode: without it BatchNorm normalises with per-batch statistics
# and keeps updating its running stats even inside torch.no_grad().
# Call model.train() again before any further training.
model.eval()

with torch.no_grad():
    for i, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        train_loss += loss.item()

        total += labels.size(0)
        # gensim works on NumPy arrays, so bring both tensors back to the CPU.
        labels, outputs = labels.cpu(), outputs.cpu()
        for l, o in zip(labels, outputs):
            label_word = word_vectors.similar_by_vector(l.numpy(), topn=1)
            output_word = word_vectors.similar_by_vector(o.numpy(), topn=1)
            correct += int(label_word[0][0] == output_word[0][0])
        if i % 200 == 199:    # print every 200 mini-batches
            # Shows the last (target, prediction) pair of the current batch.
            print(label_word, output_word)

print("Train Loss: {} | Acc: {} | {}/{}".format(train_loss/len(trainloader), 100.*correct/total, correct, total))

In [None]:
# Evaluate the current model on the test split: mean per-batch loss plus
# nearest-word accuracy (a prediction counts as correct when the word closest
# to the output vector equals the word closest to the target vector).
test_loss = 0.0
total = 0
correct = 0

# Device selection and the model transfer are loop-invariant, so do them once.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
# Evaluation mode: without it BatchNorm normalises with per-batch statistics
# and keeps updating its running stats even inside torch.no_grad().
# Call model.train() again before any further training.
model.eval()

with torch.no_grad():
    for i, (inputs, labels) in enumerate(testloader):
        inputs, labels = inputs.to(device), labels.to(device)

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        test_loss += loss.item()

        total += labels.size(0)
        # gensim works on NumPy arrays, so bring both tensors back to the CPU.
        labels, outputs = labels.cpu(), outputs.cpu()
        for l, o in zip(labels, outputs):
            label_word = word_vectors.similar_by_vector(l.numpy(), topn=1)
            output_word = word_vectors.similar_by_vector(o.numpy(), topn=1)
            correct += int(label_word[0][0] == output_word[0][0])
        if i % 200 == 199:    # print every 200 mini-batches
            # Shows the last (target, prediction) pair of the current batch.
            print(label_word, output_word)

print("Loss: {} | Acc: {} | {}/{}".format(test_loss/len(testloader), 100.*correct/total, correct, total))