In [None]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import tqdm
import wandb
from tqdm.notebook import trange

from ml_zoo import ResNetConfig, ResNet

In [None]:
# CIFAR-10 training split. ToTensor is the only transform in the pipeline
# (no augmentation or explicit normalization is applied here).
train_transform = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor()]
)
train_dataset = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=train_transform
)
# Batches are reshuffled on every pass over the training data.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=64,
    shuffle=True,
)

# CIFAR-10 held-out split, iterated in a fixed order.
test_transform = torchvision.transforms.ToTensor()
test_dataset = torchvision.datasets.CIFAR10(
    root="./data", train=False, download=True, transform=test_transform
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=64,
    shuffle=False,
)

In [None]:
def _conv_stage(in_channels, out_channels):
    """Return one stage: BatchNorm -> 3x3 conv (padding 1) -> ReLU -> 2x2 max-pool."""
    return [
        nn.BatchNorm2d(in_channels),
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(2),
    ]


# Channel widths at the input of each stage, plus the final output width.
channel_plan = [3, 32, 64, 128, 256]
feature_layers = [
    layer
    for c_in, c_out in zip(channel_plan[:-1], channel_plan[1:])
    for layer in _conv_stage(c_in, c_out)
]

# The first Linear expects 256 channels on a 2x2 grid, i.e. what four 2x2
# pools leave of a 32x32 input.
classifier_layers = [
    nn.Flatten(),
    nn.Linear(256 * 2 * 2, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
]

# Kept as a single flat Sequential so module indices (and therefore
# state-dict keys) are 0..19.
model = nn.Sequential(*feature_layers, *classifier_layers)

num_parameters = sum(p.numel() for p in model.parameters())
print(f"Model has {num_parameters:,} parameters")

In [None]:
# Plain SGD with momentum over every parameter of the model.
sgd_settings = dict(lr=0.1, momentum=0.9)
optimizer = torch.optim.SGD(model.parameters(), **sgd_settings)

# Plateau scheduler in "min" mode: the value passed to scheduler.step() is
# treated as something to minimise, and the LR is multiplied by 0.5 once it
# has stopped improving (patience=5).
plateau_settings = dict(mode="min", factor=0.5, patience=5, verbose=True)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **plateau_settings)

In [None]:
# Descriptive metadata stored in the W&B run config.
# NOTE(review): the "LeNet5" label does not match the network built in this
# notebook (four BatchNorm/conv/pool stages plus two linear layers) — confirm
# the intended label before comparing runs by model name.
run_metadata = dict(
    model="LeNet5",
    dataset="CIFAR10",
    optimizer="SGD",
    scheduler="ReduceLROnPlateau",
)

# Open the run in the "teach" project, record the metadata, then register the
# model with W&B via wandb.watch.
wandb.init(project="teach")
wandb.config.update(run_metadata)
wandb.watch(model)

In [None]:
@torch.no_grad()
def test(model):
    model.eval()
    correct = 0
    total = 0
    loss = 0
    for batch in test_loader:
        images, labels = batch
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        loss += F.cross_entropy(outputs, labels)

    accuracy = correct / total
    loss /= len(test_loader)
    return accuracy, loss

In [None]:
NUM_EPOCHS = 25

# Placeholders shown in the progress bar until the first evaluation finishes.
val_accuracy = 0
val_loss = 0
pbar = trange(NUM_EPOCHS, desc="Epochs")
for epoch in pbar:
    # test() leaves the model in eval mode, so re-enable training behaviour
    # (BatchNorm batch statistics) at the start of every epoch.
    model.train()

    for x, y in train_loader:
        optimizer.zero_grad()
        y_hat = model(x)
        loss = F.cross_entropy(y_hat, y)
        loss.backward()
        optimizer.step()

        # Convert to Python floats once and reuse for logging and display.
        batch_loss = loss.item()
        batch_accuracy = (y_hat.argmax(1) == y).float().mean().item()
        wandb.log(
            {
                "train/loss": batch_loss,
                "train/accuracy": batch_accuracy,
            },
        )

        pbar.set_postfix_str(
            f"loss={batch_loss:.4f}, val acc={val_accuracy:.2%}, val loss={val_loss:.4f}"
        )

    val_accuracy, val_loss = test(model)
    # float() accepts both a Python number and a 0-dim tensor.
    val_loss = float(val_loss)
    wandb.log({"val/accuracy": val_accuracy, "val/loss": val_loss})

    # Drive the plateau scheduler with the validation loss; the loss of the
    # last training mini-batch is too noisy to detect a plateau reliably.
    scheduler.step(val_loss)

In [None]:
# End the W&B run that was opened with wandb.init() earlier in the notebook.
wandb.finish()