In [1]:
import gensim.downloader
import numpy as np
import sys

from torch import utils
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import models, datasets
import torch


In [2]:
# Notebook-wide handle to the word-embedding model.
# Starts as None; change_target_to_word_vectors() assigns the loaded gensim
# model here, and the evaluation cells use it for nearest-word lookups.

word_vectors = None

In [8]:
def load_data_cifar10(train=True, train_on_embeddings=False, batch_size=4,
                      num_workers=2, shuffle=True,
                      root='/nethome/bdevnani3/raid/data'):
    """Build a DataLoader over CIFAR-10 with 224x224 normalized images.

    Args:
        train: If True load the training split, otherwise the test split.
        train_on_embeddings: If True, replace the integer class targets with
            word vectors via ``change_target_to_word_vectors``.
        batch_size: Number of samples per batch (default 4).
        num_workers: Number of DataLoader worker processes (default 2).
        shuffle: Whether the loader reshuffles samples (default True, which is
            applied to both splits as before).
        root: Directory where CIFAR-10 is stored / downloaded to.

    Returns:
        A ``torch.utils.data.DataLoader`` over the requested split.
    """
    transform = transforms.Compose([
            # Upscale then crop to 224x224 before converting to a tensor.
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            # Hard-coded per-channel mean/std used for normalization.
            transforms.Normalize(
                mean=[0.491, 0.482, 0.446],
                std= [0.247, 0.243, 0.261]
            )]) # TODO: Automate calculation given a dataset

    dataset = datasets.CIFAR10(root=root, train=train,
                               download=True, transform=transform)
    if train_on_embeddings:
        dataset = change_target_to_word_vectors(dataset)

    dataloader = utils.data.DataLoader(dataset, batch_size=batch_size,
                                       shuffle=shuffle, num_workers=num_workers)
    return dataloader

def change_target_to_word_vectors(dataset):
    """Replace a dataset's integer class targets with word-embedding vectors.

    Each target index is mapped to its class name through
    ``dataset.class_to_idx`` and then to that name's vector in the
    module-level ``word_vectors`` model.

    Args:
        dataset: Object exposing ``class_to_idx`` (name -> index mapping) and
            ``targets`` (iterable of class indices). It is modified in place.

    Returns:
        The same ``dataset`` object, with ``targets`` replaced by a NumPy
        array holding one embedding vector per sample.
    """
    model = 'word2vec-google-news-300'
    global word_vectors
    # Load the pretrained model only once: this function is called for every
    # split, and reloading the model on each call is redundant work.
    if word_vectors is None:
        word_vectors = gensim.downloader.load(model)

    idx_to_class = {idx: name for name, idx in dataset.class_to_idx.items()}
    dataset.targets = np.array(
        [word_vectors[idx_to_class[target]] for target in dataset.targets]
    )
    return dataset

# Build one loader per CIFAR-10 split; in both, the targets are word vectors
# rather than integer class indices.
trainloader = load_data_cifar10(
    train=True,
    train_on_embeddings=True,
)
testloader = load_data_cifar10(
    train=False,
    train_on_embeddings=True,
)

Files already downloaded and verified
Files already downloaded and verified


In [11]:
# Rebuild the ResNet-18 architecture with a 300-dimensional output head and
# restore the trained weights from disk.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = models.resnet18(pretrained=False)
model.fc = nn.Linear(in_features=512, out_features=300)
# map_location="cpu" lets a checkpoint that was saved from a GPU be restored
# on a CPU-only machine; the model is moved to the target device afterwards.
state_dict = torch.load(
    "/nethome/bdevnani3/raid/trained_models/vis_lang/pred_emb.pt",
    map_location="cpu",
)
model.load_state_dict(state_dict)
model.to(device)
# Switch to inference mode; as the last expression this also displays the model.
model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [12]:
# Evaluate the model on the training split: accumulate the loss and count a
# prediction as correct when the nearest word to the predicted vector equals
# the nearest word to the label vector.
# NOTE(review): `criterion` is not defined in any visible cell of this
# notebook; it must be defined before this cell runs on a fresh kernel.
train_loss = 0.0
total  = 0
correct = 0

# Device selection and model transfer do not change between batches, so do
# them once instead of on every iteration.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

# Nearest-word lookups for label vectors, keyed by the vector's raw bytes.
# Identical label vectors recur across samples, so each is looked up once.
label_word_cache = {}

with torch.no_grad():
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        train_loss += loss.item()

        total += labels.size(0)
        # Nearest-word lookups operate on NumPy arrays, so bring tensors to CPU.
        labels, outputs = labels.to("cpu"), outputs.to("cpu")
        for l, o in zip(labels, outputs):
            label_vec = l.numpy()
            cache_key = label_vec.tobytes()
            label_word = label_word_cache.get(cache_key)
            if label_word is None:
                label_word = word_vectors.similar_by_vector(label_vec, topn=1)
                label_word_cache[cache_key] = label_word
            output_word = word_vectors.similar_by_vector(o.numpy(), topn=1)
            # Each lookup yields [(word, similarity)]; compare the words only.
            correct += label_word[0][0] == output_word[0][0]
        if i % 200 == 199:    # print every 200 mini-batches
            print(label_word, output_word, 100.*correct/total)

print("Train Loss: {} | Acc: {} | {}/{}".format(train_loss/len(trainloader), 100.*correct/total, correct, total))

[('horse', 1.0)] [('horse', 0.9992008209228516)] 99.875
[('dog', 0.9999999403953552)] [('dog', 0.9927653670310974)] 99.8125
[('frog', 1.0)] [('frog', 0.9986981153488159)] 99.75
[('cat', 1.0)] [('cat', 0.9959831833839417)] 99.78125
[('ship', 1.0000001192092896)] [('ship', 0.9986339807510376)] 99.775


KeyboardInterrupt: 

In [None]:
# Evaluate the model on the test split: accumulate the loss and count a
# prediction as correct when the nearest word to the predicted vector equals
# the nearest word to the label vector.
# NOTE(review): `criterion` is not defined in any visible cell of this
# notebook; it must be defined before this cell runs on a fresh kernel.
test_loss = 0.0
total  = 0
correct = 0
min_loss = 0  # not read anywhere in this cell; kept in case later cells use it

# Device selection and model transfer do not change between batches, so do
# them once instead of on every iteration.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

# Nearest-word lookups for label vectors, keyed by the vector's raw bytes.
# Identical label vectors recur across samples, so each is looked up once.
label_word_cache = {}

with torch.no_grad():
    for i, data in enumerate(testloader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        test_loss += loss.item()

        total += labels.size(0)
        # Nearest-word lookups operate on NumPy arrays, so bring tensors to CPU.
        labels, outputs = labels.to("cpu"), outputs.to("cpu")
        for l, o in zip(labels, outputs):
            label_vec = l.numpy()
            cache_key = label_vec.tobytes()
            label_word = label_word_cache.get(cache_key)
            if label_word is None:
                label_word = word_vectors.similar_by_vector(label_vec, topn=1)
                label_word_cache[cache_key] = label_word
            output_word = word_vectors.similar_by_vector(o.numpy(), topn=1)
            # Each lookup yields [(word, similarity)]; compare the words only.
            correct += label_word[0][0] == output_word[0][0]
        if i % 200 == 199:    # print every 200 mini-batches
            print(label_word, output_word)


print("Loss: {} | Acc: {} | {}/{}".format(test_loss/len(testloader), 100.*correct/total, correct, total))