In [12]:
import torch, torchvision
import torchvision.transforms as transforms
from PIL import Image
import json, datetime
from sklearn.metrics import f1_score

CLASS_LABEL_PATH = "../../ADARI/furniture/ADARI_furniture_onehots.json"
IMAGE_FOLDER = "../../ADARI/v2/full"

# Global torch seed (consumed e.g. by weight initialisation and DataLoader
# shuffling). The trailing ';' suppresses the Generator repr in the cell output.
SEED = 42
torch.manual_seed(SEED);

<torch._C.Generator at 0x102d6a870>

In [13]:
def open_json(path):
    """Load and return the JSON document stored at ``path``.

    The ``with`` block guarantees the file handle is closed even when
    ``json.load`` raises (e.g. on malformed JSON). JSON text is UTF-8, so the
    encoding is fixed instead of depending on the platform locale.
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)

class ADARIMultiHotDataset(torch.utils.data.Dataset):
    """Image dataset yielding (image tensor, multi-hot label tensor) pairs.

    Args:
        image_folder: directory containing the image files; each file path is
            ``image_folder + '/' + <name>``.
        class_label_file: path to a JSON object mapping image file name to a
            label list (one entry per class; presumably 0/1 multi-hot).
        image_size: target size for Resize + CenterCrop, so every image
            becomes a square ``image_size`` x ``image_size`` crop.
        label_dtype: dtype of the returned label tensor. Defaults to
            ``torch.float64``, which is what this class has always returned.

    Raises:
        ValueError: if the label file contains no entries.
    """

    def __init__(self, image_folder, class_label_file, image_size, label_dtype=torch.float64):
        # Zero-argument super(); the old `super(ADARIMultiHotDataset)` form
        # produced an unbound super object and never ran the base __init__.
        super().__init__()

        self.image_size = image_size
        self.image_folder = image_folder
        self.class_label_file = class_label_file
        self.label_dtype = label_dtype
        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            # Maps each channel from [0, 1] to [-1, 1].
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        self.im_to_one_hots = open_json(self.class_label_file)
        self.im_names = list(self.im_to_one_hots.keys())
        if not self.im_names:
            raise ValueError(f"No labelled images found in {class_label_file!r}")
        # The class count is taken from the first entry; all entries are
        # assumed to have the same length.
        self.num_classes = len(self.im_to_one_hots[self.im_names[0]])

    def __len__(self):
        """Number of labelled images."""
        return len(self.im_names)

    def __getitem__(self, idx):
        """Return the transformed image and its label vector for index ``idx``."""
        imname = self.im_names[idx]

        # The context manager closes the underlying file. convert('RGB') loads
        # the pixels into an independent image and guarantees 3 channels;
        # grayscale/RGBA/palette images would otherwise break Normalize.
        with Image.open(self.image_folder + '/' + imname) as raw:
            img = raw.convert('RGB')
        label = torch.tensor(self.im_to_one_hots[imname], dtype=self.label_dtype)
        return self.transform(img), label
        

In [14]:
# Load Data

IMAGE_SIZE = 64  # side length of the square crops fed to the model
data = ADARIMultiHotDataset(IMAGE_FOLDER, CLASS_LABEL_PATH, IMAGE_SIZE)
vocab_size = data.num_classes

# 80/20 train/test split. A dedicated, seeded generator makes the split
# independent of how many random draws earlier cells consumed, so re-running
# this cell (or running cells out of order) yields the same partition.
SPLIT_SEED = 42
n_train = int(.8 * len(data))
train_set, test_set = torch.utils.data.random_split(
    data,
    [n_train, len(data) - n_train],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [15]:
# Create model
def build_model(num_classes=None):
    """Build a VGG16 (constructed without pretrained weights) with a new output head.

    Args:
        num_classes: number of output logits. Defaults to the notebook-level
            ``vocab_size`` (looked up at call time), as before.

    Returns:
        A torchvision VGG16 whose ``classifier[6]`` is a fresh Linear layer
        producing ``num_classes`` logits.
    """
    if num_classes is None:
        num_classes = vocab_size
    vgg = torchvision.models.vgg16()
    # Read the head's input width from the existing layer instead of
    # hard-coding 4096.
    in_features = vgg.classifier[6].in_features
    vgg.classifier[6] = torch.nn.Linear(in_features, num_classes)
    return vgg
vgg = build_model()

In [16]:
# Training Parameters
# Settings read by train() (and `device` also by the evaluation functions).
batch_size, num_workers = 64, 1
num_epochs = 100
lr = 1e-3
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [17]:
# Training loop

def one_hot_to_multilabel(y):
    """Convert a batch of multi-hot vectors to MultiLabelMarginLoss targets.

    Each output row lists, in ascending order, the column indices where the
    corresponding input row is non-zero, right-padded with -1 to the full row
    length. This is the target format torch.nn.MultiLabelMarginLoss expects.

    Example: ``[[0, 1, 0, 1]] -> [[1, 3, -1, -1]]``.

    Args:
        y: 2-D tensor of shape (batch_size, vocab_size). Any non-zero entry
            counts as a positive label.

    Returns:
        int64 tensor with the same shape as ``y``.
    """
    n_rows, n_cols = y.shape
    rows = []
    # Plain Python lists avoid per-element tensor indexing.
    for values in y.tolist():
        positives = [col for col, v in enumerate(values) if v != 0]
        rows.append(positives + [-1] * (n_cols - len(positives)))
    # reshape keeps the (0, n_cols) shape for an empty batch, where
    # torch.tensor([]) alone would give a 1-D float tensor.
    return torch.tensor(rows, dtype=torch.long).reshape(n_rows, n_cols)

def train(model):
    """Train ``model`` on ``train_set`` with BCE-with-logits loss and Adam.

    Reads the notebook-level settings ``train_set``, ``batch_size``,
    ``num_workers``, ``lr``, ``num_epochs`` and ``device``. Prints the mean
    batch loss after every epoch.

    Args:
        model: network mapping an image batch to one logit per class.

    Returns:
        List with the mean batch loss of each completed epoch.

    Raises:
        ValueError: if ``train_set`` yields no batches.
    """
    model.train()
    # Move the model before building the optimizer so it tracks the
    # on-device parameters.
    model.to(device)
    criterion = torch.nn.BCEWithLogitsLoss()
    dataloader = torch.utils.data.DataLoader(train_set,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=num_workers)
    if len(dataloader) == 0:
        raise ValueError("train_set is empty; nothing to train on")
    # Optimise the model we were given, not a notebook global.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    epoch_losses = []
    for epoch in range(num_epochs):
        losses = []
        for im, labels in dataloader:
            im = im.to(device)
            labels = labels.to(device)

            # PyTorch accumulates gradients; clear them before every step.
            optimizer.zero_grad()
            l_hat = model(im)
            # The dataset yields float64 labels; match the logits' dtype.
            loss = criterion(l_hat, labels.to(l_hat.dtype))
            loss.backward()
            optimizer.step()

            losses.append(loss.item())
        avg_loss = sum(losses) / len(losses)
        epoch_losses.append(avg_loss)
        print(f"Avg Loss at Epoch {epoch}: {avg_loss}")
    return epoch_losses
        
        

In [18]:
# Timestamp for the checkpoint file name. strftime avoids the spaces and
# colons of str(datetime), which are awkward in shells and invalid on Windows.
model_name = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
try:
    train(vgg)
except KeyboardInterrupt:
    # Deliberate: training may be stopped early from the notebook; the
    # weights learned so far are still saved below.
    print("Training interrupted; saving current weights.")
vgg.to('cpu')
torch.save(vgg.state_dict(), f"VGG16_ADARI_{model_name}.pth")

In [None]:
# For testing
# Scratch cell kept as an inert triple-quoted string, so none of the code
# inside it runs. It was used to (1) push one image through the same
# Resize/CenterCrop/ToTensor/Normalize pipeline as ADARIMultiHotDataset and
# print the shape of vgg's output, and (2) evaluate MultiLabelMarginLoss on
# the test split, using one_hot_to_multilabel to format the targets.
# NOTE(review): if revived, the single image is not converted to RGB and the
# forward passes are not wrapped in torch.no_grad(). Executing this cell as-is
# just displays the string itself as the cell output.
"""
im = Image.open(IMAGE_FOLDER + '/' + "0a2e5ec5079d9424e239d3dc639f7e1d20c6fba9.jpg")
im = transforms.Compose([transforms.Resize(64),
                               transforms.CenterCrop(64),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ])(im)
print(vgg(im.reshape(1, im.shape[0], im.shape[1], im.shape[2])).shape)


test_d = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)
for im, l in test_d:
    criterion = torch.nn.MultiLabelMarginLoss()
    l = one_hot_to_multilabel(l)
    print(l)
    out = vgg(im)
    print(out.shape)
    print(criterion(out, l))

"""

In [9]:
# Compute Test Accuracy
def test():
    vgg.eval()
    vgg.to(device)
    test_d = torch.utils.data.DataLoader(test_set, batch_size=len(test_set), shuffle=False)
    criterion = torch.nn.BCEWithLogitsLoss()
    for im, l in test_d:
        l = l.to(device)
        imhat = vgg(im.to(device))
        print(f"Test MultiLabel Soft Margin Loss: {criterion(imhat, l)}")
    

In [10]:
#test()

KeyboardInterrupt: 

In [None]:
# Compute F1 Score
def test_score(model, test_set):
    model.eval()
    model.to(device)
    test_d = torch.utils.data.DataLoader(test_set, batch_size=len(test_set), shuffle=False)
    for im, l in test_d:
        im = im.to(device)
        imhat = model(im)
        imhat.to('cpu')
        score = f1_score(l, imhat, average='weighted')
        print(f"F1 Score: {score}")
        
#test_score(vgg, test_set)

In [11]:
#vgg

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1