In [None]:
import torch
from torch import nn
import torchvision
from torch.utils import data
from tqdm import tqdm
import numpy as np
from datetime import datetime
import json
import wandb
import os

from models import MoCo
from loader import TwoCropsTransform, GaussianBlur

import sys
sys.path.append("../")
from dataset_utils import ImageNet, Places365, ArtPlaces, ArtPlacesTimesN

In [None]:
# Pick the compute device: CUDA when PyTorch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# On a CUDA device, print the name PyTorch reports for the GPU.
if device.type == "cuda":
    print(torch.cuda.get_device_name())

In [None]:
LOG_RUN = True

DATASET = "ArtPlacesTimesN"
N = None
match DATASET:
    case "ArtPlacesTimesN":
        N = 12
BATCH_SIZE = 64 # 256
EPOCHS = 120 # 200
CRITERION = "cross_entropy"
OPTIMIZER = "sgd"
SCHEDULER = "step"
match SCHEDULER:
    case "step":
        SCHEDULE = [40, 80] # [120, 160]
    case "cosine":
        SCHEDULE = None

MOMENTUM = 0.9
WEIGHT_DECAY = 1e-4
LR = 0.01 # 0.03

MODEL="resnet50" # resnet50
WEIGHTS="MOCO"
DIM = 128 # 128
K = 4544 # 65536
M = 0.999
T = 0.07
MLP = False

In [None]:
# Hyperparameter record for this run. It is passed to wandb.init and written to
# constants.json next to the checkpoints, so it should list every setting that
# influences training.
config = {
    # Data
    "dataset": DATASET,
    "n": N,
    "batch_size": BATCH_SIZE,
    "epochs": EPOCHS,
    # Optimisation
    "criterion": CRITERION,
    "learning_rate": LR,
    "optimizer": OPTIMIZER,
    "optimizer_momentum": MOMENTUM,
    "optimizer_weight_decay": WEIGHT_DECAY,
    "scheduler": SCHEDULER,
    "schedule": SCHEDULE,
    # MoCo constructor arguments
    "model": MODEL,
    "weights": WEIGHTS,
    "dim": DIM,
    "k": K,
    "m": M,
    "t": T,
    "mlp": MLP,
}

In [None]:
# Authenticate with Weights & Biases only when this run is going to be logged.
if LOG_RUN:
    wandb.login()

### Datensatz

In [None]:
NUM_WORKERS = 4

transform = torchvision.transforms.Compose([
    torchvision.transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
    torchvision.transforms.RandomGrayscale(p=0.2),
    torchvision.transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

transform = TwoCropsTransform(transform)

match DATASET:
    case "ImageNet":
        dataset = ImageNet(root=r"C:\Users\mariu\Documents\Studium\Praktikum\ImageNet_Subset", transform=transform)
        dataset_val = ImageNet(root=r"C:\Users\mariu\Documents\Studium\Praktikum\ImageNet_Subset", split="val", transform=transform)
    case "Places365":
        dataset = Places365(root=r"C:\Users\mariu\Documents\Studium\Praktikum\Places365_Subset", transform=transform)
        dataset_val = Places365(root=r"C:\Users\mariu\Documents\Studium\Praktikum\Places365_Subset", transform=transform, split="val")
    case "ArtPlaces":
        dataset = ArtPlaces(root=r"C:\Users\mariu\Documents\Development\Datasets\ArtPlaces_13371280", transform=transform)
        dataset_val = ArtPlaces(root=r"C:\Users\mariu\Documents\Development\Datasets\ArtPlaces_13371280", transform=transform, split="val")
    case "ArtPlacesTimesN":
        dataset = ArtPlacesTimesN(N, root=r"C:\Users\mariu\Documents\Development\Datasets\ArtPlaces_13371280", transform=transform)
        dataset_val = ArtPlaces(root=r"C:\Users\mariu\Documents\Development\Datasets\ArtPlaces_13371280", transform=transform, split="val")

data_loader = data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True, drop_last=True)
data_loader_val = data.DataLoader(dataset_val, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True, drop_last=True)

### Modell

In [None]:
# Collect the constructor arguments from the configuration cell.
moco_kwargs = dict(
    transform=None,
    model=MODEL,
    weights=WEIGHTS,
    dim=DIM,
    K=K,
    m=M,
    T=T,
    mlp=MLP,
)

# Build the model, switch it to training mode and move it to the selected device.
moco = MoCo(**moco_kwargs).train().to(device)

### Modell trainieren

In [None]:
if LOG_RUN:
    wandb_run_options = dict(
        project="moco",               # wandb project that receives this run
        dir=r"C:\Users\mariu\Desktop",
        config=config,                # hyperparameters and run metadata to track
    )
    wandb.init(**wandb_run_options)

In [None]:
def accuracy(output, target, topk=(1,)):
    """Compute the top-k accuracy, in percent, for every k in ``topk``.

    Args:
        output: 2-D tensor of scores; the top-k indices are taken along dim 1.
        target: tensor with one class index per row of ``output``.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one single-element float tensor per entry of ``topk``,
        holding the percentage of rows whose target is among the k
        highest-scoring indices.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # Indices of the maxk largest scores per row, sorted; transposed to (maxk, batch).
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # reshape instead of view: `correct` derives from the transposed `pred`
            # and may be non-contiguous, in which case view(-1) raises for k > 1.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

In [None]:
# Each run gets its own checkpoint directory named <model>_<dataset>_<timestamp>.
run_stamp = datetime.today().strftime('%Y%m%d_%H%M%S')
run_name = f"{MODEL.lower()}_{DATASET.lower()}_{run_stamp}"
dest = os.path.join(r"C:\Users\mariu\Documents\Studium\Praktikum\Gewichte", "moco", run_name)
os.makedirs(dest)

# Store the run configuration alongside the checkpoints.
constants_path = os.path.join(dest, "constants.json")
with open(constants_path, "w") as file:
    json.dump(config, file)

def save_model(epoch=0):
    """Save the state dict of the global ``moco`` as ``state_dict_<epoch>.pt`` in ``dest``."""
    checkpoint_path = os.path.join(dest, f"state_dict_{epoch}.pt")
    torch.save(moco.state_dict(), checkpoint_path)

In [None]:
# Criterion

match CRITERION:
    case "cross_entropy":
        criterion = nn.CrossEntropyLoss().to(device)


# Optimizer

match OPTIMIZER:
    case "sgd":
        optimizer = torch.optim.SGD(
            moco.parameters(),
            LR,
            momentum=MOMENTUM,
            weight_decay=WEIGHT_DECAY,
        )


# Scheduler

match SCHEDULER:
    case "step":
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
    case "cosine":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

In [None]:
# Main loop: one training pass and one validation pass per epoch, with optional
# wandb logging and a checkpoint every 20 epochs.
for epoch in range(EPOCHS):

    # Train

    losses = []
    # top1 = []
    # top5 = []

    tqdm_data_loader = tqdm(data_loader, unit="batch")
    tqdm_data_loader.set_description(f"Epoch {epoch+1}/{EPOCHS}")
    for i, (images, _) in enumerate(tqdm_data_loader):
        # Each batch carries two image tensors; labels are ignored.
        im_q = images[0].to(device)
        im_k = images[1].to(device)

        output, target = moco(im_q=im_q, im_k=im_k)
        loss = criterion(output, target)

        # Read the scalar once; it is reused for logging, the epoch mean and tqdm.
        loss_value = loss.item()

        # acc1, acc5 = accuracy(output, target, topk=(1, 5))

        # Report every 10th batch to wandb.
        if (i+1) % 10 == 0 and LOG_RUN:
            wandb.log({
                "loss": loss_value,
                # "acc1": acc1[0],
                # "acc5": acc5[0],
                "learning_rate": scheduler.get_last_lr()[-1],
                "epoch": epoch + 1,
                "batch": i + 1,
            })

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss_value)
        # top1.append(acc1[0])
        # top5.append(acc5[0])

        postfix = {
            "Loss": loss_value,
            # "Acc@1": acc1[0],
            # "Acc@5": acc5[0],
        }

        tqdm_data_loader.set_postfix(postfix)


    # Val

    val_losses = []
    # val_top1 = []
    # val_top5 = []

    # NOTE(review): eval() only switches module modes. Whether MoCo.forward still
    # updates internal state (e.g. a key encoder or queue) during this pass cannot
    # be seen from this notebook -- verify in models.MoCo, since such updates would
    # let validation batches influence training.
    moco.eval()
    with torch.no_grad():
        tqdm_data_loader_val = tqdm(data_loader_val, unit="batch")
        tqdm_data_loader_val.set_description(f"Epoch {epoch+1}/{EPOCHS}")
        for i, (images, _) in enumerate(tqdm_data_loader_val):
            im_q = images[0].to(device)
            im_k = images[1].to(device)

            output, target = moco(im_q=im_q, im_k=im_k)
            loss = criterion(output, target)

            # acc1, acc5 = accuracy(output, target, topk=(1, 5))

            val_losses.append(loss.item())
            # val_top1.append(acc1[0])
            # val_top5.append(acc5[0])
    moco.train()

    # Per-epoch averages of the training and validation loss.
    if LOG_RUN:
        wandb.log({
            "loss_avg": np.mean(losses),
            # "acc1_avg": np.mean(top1),
            # "acc5_avg": np.mean(top5),
            "val_loss_avg": np.mean(val_losses),
            # "val_acc1_avg": np.mean(val_top1),
            # "val_acc5_avg": np.mean(val_top5),
        })

    if (epoch+1) % 20 == 0:
        save_model(epoch=epoch + 1)

    # Advance the LR schedule only after this epoch's optimizer steps.
    # Milestone schedules: stepping after epoch index e when e + 1 is in SCHEDULE
    # takes effect from epoch index e + 1, exactly as stepping at its start would.
    # Per-epoch schedules (SCHEDULE is None): the first epoch now trains with the
    # initial LR and the last epoch is no longer run at the fully annealed value.
    if SCHEDULE is None or (epoch + 1) in SCHEDULE:
        scheduler.step()

In [None]:
# Close the Weights & Biases run that was opened for logging.
if LOG_RUN:
    wandb.finish()