In [1]:
import torch.cuda

from utils import *

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Make tensors and modules created without an explicit device land on `device`.
torch.set_default_device(device)

In [2]:
def iterate(dataset):
    """Yield batches from ``dataset`` forever, restarting it after each full pass.

    The iterable is re-iterated on every pass rather than cached, so an
    iterable that produces a fresh iterator each time (for example one that
    reshuffles) is consulted again on every pass.

    Args:
        dataset: Any re-iterable collection of batches.

    Yields:
        The batches of ``dataset``, cycling endlessly.

    Raises:
        ValueError: When a complete pass over ``dataset`` produces no batch.
            This happens for an empty dataset, or for a one-shot iterator once
            it is exhausted. Without this check the generator would spin
            forever without ever yielding, hanging the caller's ``next()``.
            The error is raised lazily, when the generator is advanced.
    """
    while True:
        producedAny = False
        for batch in dataset:
            producedAny = True
            yield batch
        if not producedAny:
            raise ValueError("iterate() was given a dataset that yields no batches")

In [3]:
class AUC(nn.Module):
    """Area under the ROC curve for binary labels, computed from raw scores.

    The module holds no parameters or buffers; it is a metric, not a loss:
    inputs are detached from autograd and the result is a plain Python float.
    """

    def __init__(self):
        super().__init__()

    def forward(self, yPred, yTrue):
        """Return the ROC AUC of ``yPred`` against ``yTrue``.

        Args:
            yPred: Scores where a higher value means "more positive". The
                tensor is flattened, so shapes such as ``(N,)`` and ``(N, 1)``
                are both accepted.
            yTrue: One label per score, flattened and converted to float.
                Expected to contain 0/1 values.

        Returns:
            float: The area under the ROC curve. Tied scores are treated as a
            single threshold, so the result does not depend on how the sort
            happens to order equal scores (a constant predictor scores 0.5).
            When the AUC is undefined (only one class present) the
            epsilon-guarded normalisation makes the result 0.0.
        """
        # Flatten so that a trailing singleton dimension cannot turn the label
        # lookup below into a multi-dimensional gather.
        yPred = yPred.detach().reshape(-1)
        yTrue = yTrue.detach().reshape(-1).float()

        sortedPred, indices = torch.sort(yPred, descending=True)
        sortedLabels = yTrue[indices]

        # Running counts of positives / negatives at or above each sorted score.
        tp = torch.cumsum(sortedLabels, dim=0)
        fp = torch.cumsum(1 - sortedLabels, dim=0)

        # A ROC point only exists at the end of each run of equal scores; the
        # final element always closes the last run.
        keep = torch.ones_like(sortedPred, dtype=torch.bool)
        keep[:-1] = sortedPred[1:] != sortedPred[:-1]

        # Anchor the curve at the origin (threshold above every score).
        tp = torch.cat([tp.new_zeros(1), tp[keep]])
        fp = torch.cat([fp.new_zeros(1), fp[keep]])

        # The epsilon avoids 0/0 when one of the classes is absent.
        tpr = tp / (tp[-1] + 1e-8)
        fpr = fp / (fp[-1] + 1e-8)

        auc = torch.trapz(tpr, fpr)

        return auc.item()

In [4]:
def _squeezeKeepBatch(outputs):
    """Drop singleton dimensions from ``outputs`` without losing the batch axis.

    A bare ``squeeze()`` also removes the leading axis when the batch holds a
    single sample, which leaves the loss with an unbatched input but a batched
    target. For batches larger than one the result equals ``outputs.squeeze()``.
    """
    squeezed = outputs.squeeze()
    if outputs.dim() > 0 and outputs.shape[0] == 1:
        squeezed = squeezed.unsqueeze(0)
    return squeezed


def trainModel(model, loaders, config):
    """Train ``model`` with Adam and cross-entropy, streaming losses to a client.

    After every optimisation step one batch is drawn from the validation
    loader (cycled endlessly via ``iterate``) and evaluated without gradients,
    so a train loss and a validation loss are reported per step.

    Args:
        model: Module to optimise in place.
        loaders: Mapping with a ``"train"`` and a ``"val"`` entry, each an
            iterable of ``(inputs, targets)`` batches; ``"train"`` must also
            support ``len()``.
        config: Object providing ``learningRate`` and ``epochs``.

    Returns:
        tuple: ``(model, optimizer)``. A ``KeyboardInterrupt`` raised during
        training is swallowed on purpose so the caller still receives the
        partially trained model and can checkpoint it.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learningRate)
    try:
        objective = torch.nn.CrossEntropyLoss()

        # Endless stream of validation batches, one consumed per training step.
        testIter = iterate(loaders["val"])

        # NOTE(review): the client is never explicitly closed here — confirm
        # whether Client needs cleanup once training ends.
        client = Client("127.0.0.1", 12954)

        # Loop-invariant: number of training batches per epoch, used for display.
        totalSteps = len(loaders["train"])

        for epoch in range(config.epochs):
            for progress, (inputs, targets) in enumerate(loaders["train"], start=1):
                inputs, targets = inputs.to(device), targets.to(device)

                # eval() is entered below on every step, so switch back each time.
                model.train()
                optimizer.zero_grad()
                outputs = model(inputs)

                loss = objective(_squeezeKeepBatch(outputs), targets)
                loss.backward()
                optimizer.step()

                trainLoss = loss.item()

                inputs1, targets1 = next(testIter)
                inputs1, targets1 = inputs1.to(device), targets1.to(device)
                with torch.no_grad():
                    model.eval()
                    outputs1 = model(inputs1)
                    loss1 = objective(_squeezeKeepBatch(outputs1), targets1)

                    testLoss = loss1.item()

                client.send("Train Loss", trainLoss)
                client.send("Test Loss", testLoss)

                # "\r" rewrites the same output line instead of flooding the cell.
                print(f"\r{epoch + 1} | {progress}/{totalSteps} | Train Loss: {trainLoss:.2f} | Test Loss: {testLoss:.2f}", end="")

    except KeyboardInterrupt:
        # Deliberate: stopping the cell by hand ends training early but still
        # returns the current model and optimizer.
        pass
    return model, optimizer

In [None]:
# Checkpoint directory to resume from; swap in the commented line to start fresh.
resume = os.path.join("checkpoints", "PCAFormer")
# resume = None
# Maps a checkpoint directory name to the model class it was saved from.
modelResolution = {"ResNet50": ResNet50, "SwinTransformerV2Tiny": SwinTransformerV2Tiny, "PCAFormer": PCAFormer}

if resume is not None:
    print(f"Resuming {os.path.basename(resume)} training")
    modelClass = modelResolution[os.path.basename(resume)]
    config = Config().load(os.path.join(resume, "config.json"))

    model = modelClass(config.model)
    # map_location: lets a checkpoint saved on a GPU be restored on a CPU-only
    # machine (and vice versa). weights_only: the file is a plain state dict,
    # so refuse to unpickle anything other than tensors/containers.
    stateDict = torch.load(
        os.path.join(resume, "checkpoint.pt"),
        map_location=device,
        weights_only=True,
    )
    model.load_state_dict(stateDict)
    # NOTE(review): optimizer.pt is written below but never restored here, so
    # a resumed run starts with a fresh optimizer state.

else:
    config = Config().load(os.path.join("configs", "config.json"))

    modelClass = PCAFormer
    model = modelClass(config.model)


# Redirect the KaggleHub download cache when the config asks for it.
if "cacheDir" in config:
    os.environ["KAGGLEHUB_CACHE"] = config.cacheDir

dataDir = download_dataset("dimensi0n/imagenet-256")

config.dataset.dataDir = dataDir
loaders = get_dataloaders(config, device)

print(f"Model has {sum(p.numel() for p in model.parameters())} parameters")

model, optimizer = trainModel(model, loaders, config)

# Reached after a manual interrupt as well, because trainModel swallows
# KeyboardInterrupt and returns; the current weights are therefore always saved.
directory = os.path.join("checkpoints", modelClass.__name__)
os.makedirs(directory, exist_ok=True)
torch.save(model.state_dict(), os.path.join(directory, "checkpoint.pt"))
torch.save(optimizer.state_dict(), os.path.join(directory, "optimizer.pt"))
config.save(os.path.join(directory, "config.json"))

Resuming PCAFormer training
KaggleHub dataset path: F:/.cache/kagglehub\datasets\dimensi0n\imagenet-256\versions\1
Model has 38916072 parameters
3 | 2127/6748 | Train Loss: 7.59 | Test Loss: 6.30