In [None]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import os
from torchvision.io import read_image
import matplotlib.pyplot as plt
from torchvision import models
from tqdm.notebook import tqdm

In [None]:
class AdversarialDataset(Dataset):
    """Image dataset described by an annotations CSV and a categories CSV.

    Each sample is ``(image, target)``:

    * ``image`` is ``<img_dir>/<ImageId>.png`` loaded with ``read_image`` and,
      if given, passed through ``x_transform``.
    * ``target`` is the row's ``TrueLabel`` value. If ``y_transform`` is given,
      it is instead ``y_transform(category_name)``, where ``category_name`` is
      the ``CategoryName`` whose ``CategoryId`` equals ``TrueLabel``.

    Args:
        annotation_file: CSV with at least ``ImageId`` and ``TrueLabel`` columns.
        categories_file: CSV with ``CategoryId`` and ``CategoryName`` columns.
        img_dir: directory holding the ``.png`` images.
        x_transform: optional callable applied to the loaded image.
        y_transform: optional callable applied to the category name; its
            result becomes the returned target.
    """

    def __init__(self, annotation_file, categories_file, img_dir, x_transform=None, y_transform=None):
        self.img_dir = img_dir
        annotations = pd.read_csv(annotation_file)
        self.categories = pd.read_csv(categories_file)
        self.images = annotations["ImageId"] + ".png"
        self.labels = annotations["TrueLabel"]
        self.x_transform = x_transform
        self.y_transform = y_transform
        # id -> name map built once, so __getitem__ avoids scanning the whole
        # categories table per sample. drop_duplicates keeps the first row per id.
        self._category_names = (
            self.categories
            .drop_duplicates("CategoryId")
            .set_index("CategoryId")["CategoryName"]
            .to_dict()
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.images.iloc[idx])
        # NOTE(review): read_image keeps the file's own channel count by default
        # (e.g. RGBA PNGs give 4 channels); confirm the images are 3-channel RGB.
        image = read_image(img_path)
        target = self.labels.iloc[idx]

        if self.x_transform:
            image = self.x_transform(image)

        if self.y_transform:
            # The transformed category name replaces the integer label as the
            # target; the image is left untouched.
            target = self.y_transform(self._category_names[target])

        return image, target

In [None]:
def show_image(datarow):
    """Plot one dataset sample with its label as the title.

    ``datarow[0]`` is an image tensor expected in (C, H, W) order, as produced
    by ``read_image``. It is reordered to (H, W, C) because ``plt.imshow``
    takes channels-last arrays. ``datarow[1]`` is used as the plot title.
    """
    image_chw = datarow[0]
    caption = datarow[1]
    image_hwc = image_chw.permute(1, 2, 0)
    plt.title(caption)
    plt.imshow(image_hwc)

In [None]:
# read_image yields uint8 tensors, which ToTensor rejects (it only accepts PIL
# images or ndarrays). ConvertImageDtype gives the same float [0, 1] scaling
# for tensor input.
transform = transforms.Compose([
    transforms.ConvertImageDtype(torch.float),
])

In [None]:
# Use the first visible GPU when PyTorch can see one; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

In [None]:
BATCH_SIZE = 16
LEARNING_RATE = 1e-3

# NOTE(review): no x_transform is passed, so images reach the model as raw
# 0-255 values (cast to float in the training loop). Consider a scaling transform.
dataset = AdversarialDataset("images.csv", "categories.csv", "images")
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# The training loop sends every batch to `device`, so the model must live there
# too. Move it before building the optimizer, as PyTorch recommends.
# NOTE(review): inception_v3 defaults to 1000 output classes; confirm the
# TrueLabel values lie in [0, 999] or CrossEntropyLoss will reject them.
model = models.inception_v3(weights=None).to(device)

loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [None]:
# One training epoch over `dataloader`.
# InceptionV3 returns (logits, aux_logits) only in training mode. Set it
# explicitly so the tuple unpacking below works even if an earlier cell
# switched the model to eval().
model.train()

running_loss = 0.0  # sum of per-sample losses over the epoch
n_correct = 0       # count of correctly classified samples

for x_batch, y_batch in tqdm(dataloader):
    x_batch = x_batch.to(device, dtype=torch.float)
    # CrossEntropyLoss expects int64 class indices.
    y_batch = y_batch.to(device, dtype=torch.long)

    output, _ = model(x_batch)  # auxiliary logits are ignored
    loss = loss_function(output, y_batch)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # `loss` is the batch mean; scale by batch size to accumulate a per-sample sum.
    running_loss += loss.item() * x_batch.size(0)

    preds = output.argmax(dim=1)
    n_correct += (preds == y_batch).sum().item()

n_samples = len(dataloader.dataset)
print(f"mean loss: {running_loss / n_samples:.4f}  accuracy: {n_correct / n_samples:.4f}")