In [17]:
import wandb
import json
import torchsummary
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler

In [18]:
with open("../../data/api-key.json") as key:
    wandb.login(key=json.load(key)['wandb'])



In [19]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms

import numpy as np

In [20]:
class Cifar10(Dataset):
    def __init__(self, path, train=True, transform=None) -> None:
        super().__init__()
        self.path = path
        self.train = train
        self.transform = transform

        images = []
        labels = []
        if train:
            for i in range(1, 6):
                data = self.unpickle(path + f"/data_batch_{i}")
                images.append(data[b"data"].reshape(-1, 3, 32, 32))
                labels.append(data[b"labels"])
            self.images = np.concatenate(images).transpose((0, 2, 3, 1))
            self.labels = np.concatenate(labels)
        else:
            data = self.unpickle(path + f"/test_batch")
            self.images = data[b"data"].reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))
            self.labels = data[b"labels"]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        images = self.images[index]
        lables = self.labels[index]
        if self.transform is not None:
            images = self.transform(images)
        return images, lables
    
    def unpickle(self, path):
        import pickle
        with open(path, 'rb') as fo:
            data = pickle.load(fo, encoding='bytes')
        return data

In [21]:
class Net(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.conv = nn.Sequential(
            nn.Conv2d(3, 64, 3),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3),
            nn.ReLU()
        )
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(128*28*28, 10)
        # self.fc = nn.LazyLinear(10)
    
    def forward(self, x):
        x = self.conv(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x

In [22]:
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"device: {device}")

device: mps


In [23]:
testmodel = Net()
torchsummary.summary(testmodel, (3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 30, 30]           1,792
              ReLU-2           [-1, 64, 30, 30]               0
            Conv2d-3          [-1, 128, 28, 28]          73,856
              ReLU-4          [-1, 128, 28, 28]               0
           Flatten-5               [-1, 100352]               0
            Linear-6                   [-1, 10]       1,003,530
Total params: 1,079,178
Trainable params: 1,079,178
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 3.18
Params size (MB): 4.12
Estimated Total Size (MB): 7.30
----------------------------------------------------------------


In [24]:
config = {
    "lr": 1e-3,
    "batch_size": 32,
    "epochs": 30
}
wandb.init(project="study", config=config)

0,1
acc,▁▂▂▃▃▄▄▅▅▅▅▆▆▆▇▆▇▇▇▇█▇████████
loss,█▇▆▅▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁

0,1
acc,56.42971
loss,1.14221


In [25]:
train_valid_dataset = Cifar10(path="../../data/cifar-10-batches-py", train=True, transform=transforms.ToTensor())
test_dataset = Cifar10(path="../../data/cifar-10-batches-py", train=False, transform=transforms.ToTensor())

train_dataset, valid_dataset = random_split(train_valid_dataset, [40000, 10000])
train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=config["batch_size"], shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=config["batch_size"], shuffle=False)

model = Net().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=config["lr"])

In [26]:
def train(train_dataloader, valid_dataloader, model, criterion, optimizer, config):
    for epoch in range(1, config["epochs"] + 1):
        epoch_loss = 0
        model.train()
        for image, label in train_dataloader:
            image, label = image.to(device), label.to(device)

            pred = model(image)
            loss = criterion(pred, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        with torch.no_grad():
            acc = 0
            model.eval()
            for image, label in valid_dataloader:
                image, label = image.to(device), label.cpu()
                pred = model(image).cpu()
                acc += (pred.argmax(1) == label).float().sum().item()
        
        print(f"epoch: {epoch:3}, loss: {epoch_loss / len(train_dataloader):3f}, acc: {acc / len(test_loader) / config['batch_size'] * 100:.2f}%")
        wandb.log({
            "loss": epoch_loss / len(train_dataloader),
            "acc": acc / len(test_loader) / config["batch_size"] * 100
            })
        tune.report(acc=acc)

In [27]:
def test(dataloader, model, device):
    with torch.no_grad():
        acc = 0
        model.eval()
        for image, label in dataloader:
            image, label = image.to(device), label.cpu()
            pred = model(image).cpu()
            acc += (pred.argmax(1) == label).float().sum().item()
        print(f"acc: {acc / len(test_loader) / config['batch_size'] * 100:.2f}%")

In [28]:
def temp(config):
    train_valid_dataset = Cifar10(path="../../data/cifar-10-batches-py", train=True, transform=transforms.ToTensor())

    train_dataset, valid_dataset = random_split(train_valid_dataset, [40000, 10000])
    train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=config["batch_size"], shuffle=True)

    model = Net().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=config["lr"])
    train(train_loader, valid_loader, model, criterion, optimizer, config)
    

In [32]:
search_space = {
    "lr": tune.sample_from(lambda spec: 10 ** (-10 * np.random.rand())),
    "batch_size": 32,
    "epochs": 10
}
tunner = tune.Tuner(temp, param_space=search_space)

In [None]:
def train_cifar(config, checkpoint_dir, data_dir):
    train_valid_dataset = Cifar10(path="../../data/cifar-10-batches-py", train=True, transform=transforms.ToTensor())
    test_dataset = Cifar10(path="../../data/cifar-10-batches-py", train=False, transform=transforms.ToTensor()) 

    train_dataset, valid_dataset = random_split(train_valid_dataset, [40000, 10000])
    train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=config["batch_size"], shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=config["batch_size"], shuffle=False)  

    model = Net().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=config["lr"])

    train()


In [13]:
train(train_loader, valid_loader, model, criterion, optimizer, device, config["epochs"])
test(test_loader, model, device)

epoch:   1, loss: 2.059463, acc: 33.82%
epoch:   2, loss: 1.865511, acc: 36.72%
epoch:   3, loss: 1.781011, acc: 38.66%
epoch:   4, loss: 1.719234, acc: 40.85%
epoch:   5, loss: 1.669013, acc: 41.22%
epoch:   6, loss: 1.624945, acc: 44.56%
epoch:   7, loss: 1.584314, acc: 44.28%
epoch:   8, loss: 1.547626, acc: 45.46%
epoch:   9, loss: 1.513358, acc: 47.04%
epoch:  10, loss: 1.481849, acc: 47.63%
epoch:  11, loss: 1.453779, acc: 48.05%
epoch:  12, loss: 1.428120, acc: 49.56%
epoch:  13, loss: 1.403040, acc: 49.17%
epoch:  14, loss: 1.380928, acc: 50.42%
epoch:  15, loss: 1.358372, acc: 51.97%
epoch:  16, loss: 1.336970, acc: 50.11%
epoch:  17, loss: 1.315795, acc: 53.12%
epoch:  18, loss: 1.297201, acc: 53.37%
epoch:  19, loss: 1.279061, acc: 54.19%
epoch:  20, loss: 1.261036, acc: 53.42%
epoch:  21, loss: 1.246961, acc: 54.94%
epoch:  22, loss: 1.230391, acc: 54.85%
epoch:  23, loss: 1.218494, acc: 55.10%
epoch:  24, loss: 1.207496, acc: 55.60%
epoch:  25, loss: 1.194753, acc: 56.17%
