In [1]:
import matplotlib.pyplot as plt
import torch
from torch import nn, optim
from jcopdl.callback import Callback, set_config

# Seed torch's global RNG so weight init, DataLoader shuffling and the random
# torchvision augmentations (which draw from torch's RNG) repeat identically
# across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Prefer the first CUDA GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [3]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [8]:
bs, crop_size = 32, 224

# Training augmentation: small rotation, a random crop covering 80-100% of the
# image resized to crop_size, and a random horizontal flip.
train_transform = transforms.Compose(
    [
        transforms.RandomRotation(10),
        transforms.RandomResizedCrop(crop_size, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)

# Deterministic evaluation pipeline: resize to 230, then take the central
# crop_size x crop_size window.
test_transform = transforms.Compose(
    [
        transforms.Resize(230),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
    ]
)

# One sub-folder per class under each split directory.
train_set = datasets.ImageFolder("data/train/", transform=train_transform)
test_set = datasets.ImageFolder("data/test/", transform=test_transform)

trainloader = DataLoader(train_set, batch_size=bs, shuffle=True, num_workers=0)
# Shuffling the test loader does not change the epoch metrics (they are summed
# over the whole set); it only varies which images the preview batch shows.
testloader = DataLoader(test_set, batch_size=bs, shuffle=True)

In [9]:
# Pull one augmented training batch to sanity-check the tensor layout (N, C, H, W).
feature, target = next(iter(trainloader))
feature.size()

torch.Size([32, 3, 224, 224])

In [10]:
# Lookup from integer label to class-folder name (used to title the previews).
label2cat = list(train_set.classes)
label2cat

['adidas', 'nike']

In [11]:
from jcopdl.layers import linear_block, conv_block

In [12]:
class CNN(nn.Module):
    """Four-block convolutional classifier for 3-channel crop_size x crop_size images.

    Args:
        num_classes: number of output classes. Defaults to 2, matching the two
            ImageFolder classes this notebook trains on.

    The final ``linear_block`` uses ``activation='lsoftmax'``; this notebook
    scores its output with ``nn.NLLLoss``, which expects log-probabilities.
    """

    # Width of the flattened conv output fed to the first linear layer:
    # 16 channels x 14 x 14 = 3136.
    # NOTE(review): assumes each conv_block halves H and W (224 -> 14 after four
    # blocks); recompute if crop_size or the channel list changes.
    FLAT_FEATURES = 16 * 14 * 14

    def __init__(self, num_classes=2):
        super().__init__()
        channels = [3, 256, 128, 64, 16]
        # Same module order/indices as a hand-written Sequential, so checkpoint
        # state_dict keys are unchanged.
        self.conv = nn.Sequential(
            *(conv_block(c_in, c_out) for c_in, c_out in zip(channels, channels[1:])),
            nn.Flatten(),
        )

        self.fc = nn.Sequential(
            linear_block(self.FLAT_FEATURES, 128, dropout=0.2),
            linear_block(128, num_classes, activation='lsoftmax'),
        )

    def forward(self, x):
        return self.fc(self.conv(x))

In [13]:
# Run settings handed to jcopdl's Callback in the model-setup cell.
config = set_config(dict([("batch size", bs), ("crop_size", crop_size)]))

In [14]:
# Model, loss, optimizer and training callback. The CNN's last layer uses
# 'lsoftmax', hence the log-probability loss NLLLoss.
model = CNN().to(device)
criterion = nn.NLLLoss(reduction="mean")
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
callback = Callback(model, config, outdir="model")

In [None]:
from tqdm.auto import tqdm

def loop_fn(mode, dataset, dataloader, model, criterion, optimizer, device):
    """Run one epoch over ``dataloader`` in training or evaluation mode.

    Args:
        mode: "train" (forward, backward, optimizer step) or "test"
            (forward only, autograd disabled).
        dataset: dataset behind ``dataloader``; its length is the denominator
            for the averaged loss and accuracy.
        dataloader: iterable of ``(feature, target)`` batches.
        model: module whose output is passed to ``criterion``; ``argmax(1)`` of
            the output is taken as the predicted class.
        criterion: loss returning a per-batch scalar. It is assumed to be
            mean-reduced, since it is re-weighted by batch size below.
        optimizer: stepped only in "train" mode.
        device: device each batch is moved to.

    Returns:
        (cost, acc): sample-weighted mean loss and accuracy over ``dataset``.

    Raises:
        ValueError: if ``mode`` is neither "train" nor "test".
    """
    if mode not in ("train", "test"):
        raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")
    is_train = mode == "train"
    # train(False) is equivalent to eval(): toggles dropout and similar layers.
    model.train(is_train)

    cost = correct = 0
    # Track gradients only when training, so "test" never builds a graph even
    # if the caller forgets to wrap the call in torch.no_grad().
    with torch.set_grad_enabled(is_train):
        for feature, target in tqdm(dataloader, desc=mode.title()):
            feature, target = feature.to(device), target.to(device)
            output = model(feature)
            loss = criterion(output, target)

            if is_train:
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

            # Undo the per-batch mean so the epoch average is weighted per sample.
            cost += loss.item() * feature.shape[0]
            correct += (output.argmax(1) == target).sum().item()
    cost = cost / len(dataset)
    acc = correct / len(dataset)
    return cost, acc

In [None]:
# Train until the early-stopping callback signals a stop.
# NOTE(review): there is no max-epoch cap; termination depends entirely on
# callback.early_stopping(...) eventually returning True.
while True:
    # One training epoch: forward, backward and optimizer step over trainloader.
    train_cost, train_score = loop_fn("train", train_set, trainloader, model, criterion, optimizer, device)
    # Evaluation epoch; no_grad skips building the autograd graph.
    with torch.no_grad():
        test_cost, test_score = loop_fn("test", test_set, testloader, model, criterion, optimizer, device)
    
    # Logging: pass this epoch's mean losses and accuracies to the callback.
    callback.log(train_cost, test_cost, train_score, test_score)

    # Checkpoint
    callback.save_checkpoint()
        
    # Runtime Plotting: refresh the live cost and score curves.
    callback.cost_runtime_plotting()
    callback.score_runtime_plotting()
    
    # Early Stopping, monitored on test accuracy (test_score). When it fires,
    # draw the final cost/score curves and leave the loop.
    # NOTE(review): whether early_stopping also restores the best weights into
    # `model` is a jcopdl detail; confirm before relying on `model` afterwards.
    if callback.early_stopping(model, monitor="test_score"):
        callback.plot_cost()
        callback.plot_score()
        break

In [None]:
# Fetch one test batch on the model's device for a qualitative look at predictions.
feature, target = (tensor.to(device) for tensor in next(iter(testloader)))

In [None]:
# Inference on the preview batch: eval mode, no autograd, class = argmax over dim 1.
model.eval()
with torch.no_grad():
    output = model(feature)
preds = output.argmax(1)
preds

In [None]:
# Preview predictions on the test batch: green title = correct, red = wrong.
# The 6x6 grid has 36 panels, but a batch holds at most `bs` (32) images, so
# any leftover panels are blanked after the loop.
fig, axes = plt.subplots(6, 6, figsize=(24, 24))
flat_axes = axes.flatten()
for image, label, pred, ax in zip(feature, target, preds, flat_axes):
    # CHW tensor -> HWC on the CPU, the layout imshow expects.
    ax.imshow(image.permute(1, 2, 0).cpu())
    color = "g" if label == pred else "r"
    true_name, pred_name = label2cat[label.item()], label2cat[pred.item()]
    ax.set_title(f"L : {true_name} | P : {pred_name}", fontdict={"color": color})
    ax.axis("off")
for ax in flat_axes[len(preds):]:
    ax.axis("off")