In [21]:
import sys
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image
import matplotlib.pyplot as plt
import torchvision.transforms.functional as F
from torch.optim import Adam
import tqdm
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image
import matplotlib.pyplot as plt
import torchvision.transforms.functional as F
from torch.optim import Adam
import torcheval.metrics as metrics
import json
import argparse

In [None]:
args = {
    'modelfile':'wikiart.pth',
    'trainingdir':'/home/guserbto@GU.GU.SE/wikiart/train',
    'testingdir': '/home/guserbto@GU.GU.SE/wikiart/test',
    'device': 'cuda:1',
    'epochs':20,
    'batch_size':32,
    
}

modelfile = args['modelfile']
trainingdir = args['trainingdir']
testingdir = args['testingdir']
device = args['device']
epochs = args['epochs']
batch_size = args['batch_size']

In [26]:
class WikiArtImage:
    def __init__(self, imgdir, label, filename):
        self.imgdir = imgdir
        self.label = label
        self.filename = filename
        self.image = None
        self.loaded = False

    def get(self):
        if not self.loaded:
            self.image = read_image(os.path.join(self.imgdir, self.label, self.filename)).float()
            self.loaded = True

        return self.image

In [40]:
class WikiArtDataset(Dataset):
    def __init__(self, imgdir, device="cpu"):
        walking = os.walk(imgdir)
        filedict = {}
        indices = []
        classes = set()
        print("Gathering files for {}".format(imgdir))
        for item in walking:
            sys.stdout.write('.')
            arttype = os.path.basename(item[0])
            artfiles = item[2]
            for art in artfiles:
                filedict[art] = WikiArtImage(imgdir, arttype, art)
                indices.append(art)
                classes.add(arttype)
        print("...finished")
        self.filedict = filedict
        self.imgdir = imgdir
        self.indices = indices
        self.classes = list(classes)
        self.device = device
        
    def __len__(self):
        return len(self.filedict)

    def __getitem__(self, idx):
        imgname = self.indices[idx]
        imgobj = self.filedict[imgname]
        ilabel = self.classes.index(imgobj.label)
        image = imgobj.get().to(self.device)

        return image, ilabel

In [28]:
class WikiArtModel(nn.Module):
    def __init__(self, num_classes=27):
        super().__init__()

        self.conv2d = nn.Conv2d(3, 1, (4,4), padding=2)
        self.maxpool2d = nn.MaxPool2d((4,4), padding=2)
        self.flatten = nn.Flatten()
        self.batchnorm1d = nn.BatchNorm1d(105*105)
        self.linear1 = nn.Linear(105*105, 300)
        self.dropout = nn.Dropout(0.01)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(300, num_classes)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, image):
        output = self.conv2d(image)
        #print("convout {}".format(output.size()))
        output = self.maxpool2d(output)
        #print("poolout {}".format(output.size()))        
        output = self.flatten(output)
        output = self.batchnorm1d(output)
        #print("poolout {}".format(output.size()))        
        output = self.linear1(output)
        output = self.dropout(output)
        output = self.relu(output)
        output = self.linear2(output)
        return self.softmax(output)

In [41]:
# TRAIN

#parser = argparse.ArgumentParser()
#parser.add_argument("-c", "--config", help="configuration file", default="config.json")

#args = parser.parse_args()

#config = json.load(open(config))

#trainingdir = config["trainingdir"]
#testingdir = config["testingdir"]
#device = config["device"]

print("Running...")


traindataset = WikiArtDataset(trainingdir, device)
#testingdataset = WikiArtDataset(testingdir, device)

print(traindataset.imgdir)

the_image, the_label = traindataset[5]
print(the_image, the_image.size())

# the_showable_image = F.to_pil_image(the_image)
# print("Label of img 5 is {}".format(the_label))
# the_showable_image.show()


def train(epochs=3, batch_size=32, modelfile=None, device="cpu"):
    loader = DataLoader(traindataset, batch_size=batch_size, shuffle=True)

    model = WikiArtModel().to(device)
    optimizer = Adam(model.parameters(), lr=0.01)
    criterion = nn.NLLLoss().to(device)
    
    for epoch in range(epochs):
        print("Starting epoch {}".format(epoch))
        accumulate_loss = 0
        for batch_id, batch in enumerate(tqdm.tqdm(loader)):
            X, y = batch
            y = y.to(device)
            optimizer.zero_grad()
            output = model(X)
            loss = criterion(output, y)
            loss.backward()
            accumulate_loss += loss
            optimizer.step()

        print("In epoch {}, loss = {}".format(epoch, accumulate_loss))

    if modelfile:
        torch.save(model.state_dict(), modelfile)

    return model

Running...
Gathering files for /home/guserbto@GU.GU.SE/wikiart/train
...............................finished
/home/guserbto@GU.GU.SE/wikiart/train
tensor([[[171., 156., 148.,  ..., 255., 255., 255.],
         [138., 183., 233.,  ..., 255., 255., 255.],
         [152., 136., 132.,  ..., 255., 255., 255.],
         ...,
         [164., 163., 160.,  ..., 160., 161., 161.],
         [161., 159., 156.,  ..., 160., 164., 168.],
         [160., 156., 152.,  ..., 153., 159., 167.]],

        [[173., 157., 148.,  ..., 255., 255., 255.],
         [140., 184., 233.,  ..., 255., 255., 255.],
         [153., 137., 131.,  ..., 255., 255., 255.],
         ...,
         [177., 176., 173.,  ..., 180., 177., 174.],
         [174., 172., 169.,  ..., 180., 180., 181.],
         [173., 169., 165.,  ..., 173., 175., 180.]],

        [[124., 113., 112.,  ..., 255., 255., 255.],
         [ 93., 140., 199.,  ..., 255., 255., 255.],
         [109.,  95., 100.,  ..., 255., 255., 255.],
         ...,
         [13

In [None]:
model = train(args["epochs"], args["batch_size"], modelfile=args["modelfile"], device=device)

Starting epoch 0


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 418/418 [00:30<00:00, 13.49it/s]


In epoch 0, loss = 3886.69287109375
Starting epoch 1


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 418/418 [00:09<00:00, 42.92it/s]


In epoch 1, loss = 1133.0986328125
Starting epoch 2


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 418/418 [00:07<00:00, 53.40it/s]


In epoch 2, loss = 1151.4888916015625
Starting epoch 3


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 418/418 [00:06<00:00, 65.10it/s]


In epoch 3, loss = 1152.2574462890625
Starting epoch 4


 97%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎      | 406/418 [00:06<00:00, 61.44it/s]

In [None]:
# TEST

#parser = argparse.ArgumentParser()
#parser.add_argument("-c", "--config", help="configuration file", default="config.json")

#args = parser.parse_args()

#config = json.load(open(args.config))

#testingdir = config["testingdir"]
#device = config["device"]


print("Running...")

#traindataset = WikiArtDataset(trainingdir, device)
testingdataset = WikiArtDataset(testingdir, device)

def test(modelfile=None, device="cpu"):
    loader = DataLoader(testingdataset, batch_size=1)

    model = WikiArtModel()
    model.load_state_dict(torch.load(modelfile, weights_only=True))
    model = model.to(device)
    model.eval()

    predictions = []
    truth = []
    for batch_id, batch in enumerate(tqdm.tqdm(loader)):
        X, y = batch
        y = y.to(device)
        output = model(X)
        predictions.append(torch.argmax(output).unsqueeze(dim=0))
        truth.append(y)

    #print("predictions {}".format(predictions))
    #print("truth {}".format(truth))
    predictions = torch.concat(predictions)
    truth = torch.concat(truth)
    metric = metrics.MulticlassAccuracy()
    metric.update(predictions, truth)
    print("Accuracy: {}".format(metric.compute()))
    confusion = metrics.MulticlassConfusionMatrix(27)
    confusion.update(predictions, truth)
    print("Confusion Matrix\n{}".format(confusion.compute()))

In [None]:
test(modelfile=config["modelfile"], device=device)