In [68]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import os
from torchvision.io import read_image
import matplotlib.pyplot as plt
from torchvision import models
from tqdm.notebook import tqdm

In [69]:
class AdversarialDataset(Dataset):
    def __init__(self, annotation_file, categories_file, img_dir, x_transform=None, y_transform=None):
        self.img_dir = img_dir
        annotations = pd.read_csv(annotation_file)
        self.categories = pd.read_csv(categories_file)
        self.images = annotations["ImageId"] + ".png"
        self.labels = annotations["TrueLabel"]
        self.x_transform = x_transform
        self.y_transform = y_transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.images[idx])
        image = read_image(img_path)

        if self.x_transform:
            image = self.x_transform(image)

        return image, self.labels[idx] - 1

In [70]:
df = pd.read_csv("images.csv")

df.head()

Unnamed: 0,ImageId,URL,x1,y1,x2,y2,TrueLabel,TargetClass,OriginalLandingURL,License,Author,AuthorProfileURL
0,0c7ac4a8c9dfa802,https://c1.staticflickr.com/9/8540/28821627444...,0.0,0.0,0.871838,1.0,306,779,https://www.flickr.com/photos/gails_pictures/2...,https://creativecommons.org/licenses/by/2.0/,gailhampshire,https://www.flickr.com/people/gails_pictures/
1,f43fbfe8a9ea876c,https://c1.staticflickr.com/9/8066/28892033183...,0.25,0.0,1.0,0.599758,884,378,https://www.flickr.com/photos/barty/28892033183,https://creativecommons.org/licenses/by/2.0/,Barry Badcock,https://www.flickr.com/people/barty/
2,4fc263d35a3ad3ee,https://c1.staticflickr.com/8/7378/27465801596...,0.333333,0.0,1.0,1.0,244,123,https://www.flickr.com/photos/foxcroftacademy/...,https://creativecommons.org/licenses/by/2.0/,Foxcroft Academy,https://www.flickr.com/people/foxcroftacademy/
3,cc13c2bc5cdd1f44,https://c1.staticflickr.com/9/8864/28546467522...,0.0,0.0,0.5,0.75,560,741,https://www.flickr.com/photos/o_0/28546467522/,https://creativecommons.org/licenses/by/2.0/,Guilhem Vellut,https://www.flickr.com/people/o_0/
4,73a52afd2f818ed5,https://c1.staticflickr.com/6/5607/31066602702...,0.489195,0.0,1.0,0.75,439,696,https://www.flickr.com/photos/chemiebw/3106660...,https://creativecommons.org/licenses/by/2.0/,Chemie-Verb\303\244nde Baden-W\303\274rttemberg,https://www.flickr.com/people/chemiebw/


In [71]:
def show_image(datarow):
    permuted = torch.permute(datarow[0], (1, 2, 0))
    plt.title(datarow[1])
    plt.imshow(permuted)

In [72]:
transform = transforms.Compose([
    transforms.ToTensor()
])

In [73]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [74]:
dataset = AdversarialDataset("images.csv", "categories.csv", "images")
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

model = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT)
model.to(device)

loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

<h1>Training</h1>

In [75]:
n_epochs = 0
running_loss = 0
n_correct = 0

for epoch in range(n_epochs):
    for x_batch, y_batch in tqdm(dataloader):
        x_batch = x_batch.float()
        x_batch = x_batch.to(device)
        y_batch = y_batch.type(torch.LongTensor)
        y_batch = y_batch.to(device)
    
        output, _ = model(x_batch)
    
        loss = loss_function(output, y_batch)
    
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
        running_loss += loss.item() + x_batch.size(0)
    
        _, preds = torch.max(output, dim=1)
        n_correct += torch.sum(preds == y_batch.data)

<h1>Testing</h1>

In [76]:
n_correct = 0

for x_batch, y_batch in tqdm(dataloader):
    x_batch = x_batch.to(device)
    y_batch = y_batch.to(device)

    output, _ = model(x_batch)

    _, preds = torch.max(output, dim=1)
    n_correct += torch.sum(preds == y_batch)
    
print(n_correct/len(dataset))

  0%|          | 0/63 [00:00<?, ?it/s]

tensor(0.8230, device='cuda:0')
