# CIFAR 100

In [None]:
from torch.utils.tensorboard import SummaryWriter
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
import time

  import pynvml  # type: ignore[import]


In [10]:
# Hyperparameters
batch_size = 64
learning_rate = 1e-4
epochs = 10

# One TensorBoard log directory per launch, stamped with the lowercased
# month/day/hour/minute, e.g. "./runs/mar_03_1422".
time_str = time.strftime("%b_%d_%H%M")
time_str = time_str.lower()
run_path = "./runs/" + time_str
w = SummaryWriter(run_path)

# All tensors and the model are placed on the CPU in this notebook.
device = "cpu"

In [3]:
class NeuralNetwork(nn.Module):
    """Small CNN classifier for 3x32x32 images (CIFAR-100 by default).

    Layout: conv(3->6) -> BN -> ReLU -> maxpool -> conv(6->12) -> BN -> ReLU
    -> conv(12->6) -> BN -> ReLU -> maxpool -> flatten (384) -> linear(256)
    -> ReLU -> dropout -> linear(num_classes).

    Args:
        dropout: drop probability applied between the two linear layers.
        num_classes: number of output logits (100 for CIFAR-100).
    """

    def __init__(self, dropout=0.5, num_classes=100):
        super().__init__()
        self.conv_layer1 = nn.Conv2d(3, 6, (3, 3), padding='same')
        self.batch_norm1 = nn.BatchNorm2d(6)
        self.max_pool1 = nn.MaxPool2d((2, 2))
        self.conv_layer2 = nn.Conv2d(6, 12, (3, 3), padding='same')
        self.batch_norm2 = nn.BatchNorm2d(12)
        self.conv_layer3 = nn.Conv2d(12, 6, (3, 3), padding='same')
        self.batch_norm3 = nn.BatchNorm2d(6)
        self.max_pool2 = nn.MaxPool2d((2, 2))
        # Two 2x2 pools shrink a 32x32 input to 8x8 and conv_layer3 outputs
        # 6 channels, so the flattened feature vector has 6 * 8 * 8 = 384 entries.
        self.dense1 = nn.Linear(6 * 8 * 8, 256)
        self.dense2 = nn.Linear(256, num_classes)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x: torch.Tensor):
        """Return raw class logits of shape (batch, num_classes).

        Expects ``x`` of shape (batch, 3, 32, 32); other spatial sizes would
        not match dense1's 384 input features.
        """
        x = self.conv_layer1(x)  # 6x32x32
        x = self.batch_norm1(x)
        x = F.relu(x)
        x = self.max_pool1(x)  # 6x16x16
        x = self.conv_layer2(x)  # 12x16x16
        x = self.batch_norm2(x)
        x = F.relu(x)
        x = self.conv_layer3(x)  # 6x16x16
        x = self.batch_norm3(x)
        x = F.relu(x)
        x = self.max_pool2(x)  # 6x8x8
        x = torch.flatten(x, 1)  # 384 features
        x = self.dense1(x)
        x = F.relu(x)
        x = self.dropout(x)
        # No softmax here: nn.CrossEntropyLoss applies log-softmax to logits itself.
        x = self.dense2(x)
        return x

In [4]:
# Training images get light augmentation (random 4px-padded crop + horizontal
# flip); evaluation images are only converted to tensors.
train_transforms = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

test_transforms = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = datasets.CIFAR100(
    root='./data',
    train=True,
    download=True,
    transform=train_transforms,
)
test_dataset = datasets.CIFAR100(
    root='./data',
    train=False,
    download=True,
    transform=test_transforms,
)

train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)
# Evaluate on the held-out split: building this loader from train_dataset would
# report "test" metrics on (augmented) training images.
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)

# Log one augmented training batch as an image grid so the inputs can be
# inspected in TensorBoard.
images, labels = next(iter(train_loader))
img_grid = torchvision.utils.make_grid(images)
w.add_image('cifar100_images', img_grid)

  import pynvml  # type: ignore[import]
  import pynvml  # type: ignore[import]


In [5]:
class NeuralNetwork(nn.Module):
    """Small CNN classifier for 3x32x32 images (CIFAR-100 by default).

    Layout: conv(3->6) -> BN -> ReLU -> maxpool -> conv(6->12) -> BN -> ReLU
    -> conv(12->6) -> BN -> ReLU -> maxpool -> flatten (384) -> linear(256)
    -> ReLU -> dropout -> linear(num_classes).

    Args:
        dropout: drop probability applied between the two linear layers.
        num_classes: number of output logits (100 for CIFAR-100).
    """

    def __init__(self, dropout=0.5, num_classes=100):
        super().__init__()
        self.conv_layer1 = nn.Conv2d(3, 6, (3, 3), padding='same')
        self.batch_norm1 = nn.BatchNorm2d(6)
        self.max_pool1 = nn.MaxPool2d((2, 2))
        self.conv_layer2 = nn.Conv2d(6, 12, (3, 3), padding='same')
        self.batch_norm2 = nn.BatchNorm2d(12)
        self.conv_layer3 = nn.Conv2d(12, 6, (3, 3), padding='same')
        self.batch_norm3 = nn.BatchNorm2d(6)
        self.max_pool2 = nn.MaxPool2d((2, 2))
        # Two 2x2 pools shrink a 32x32 input to 8x8 and conv_layer3 outputs
        # 6 channels, so the flattened feature vector has 6 * 8 * 8 = 384 entries.
        self.dense1 = nn.Linear(6 * 8 * 8, 256)
        self.dense2 = nn.Linear(256, num_classes)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x: torch.Tensor):
        """Return raw class logits of shape (batch, num_classes).

        Expects ``x`` of shape (batch, 3, 32, 32); other spatial sizes would
        not match dense1's 384 input features.
        """
        x = self.conv_layer1(x)  # 6x32x32
        x = self.batch_norm1(x)
        x = F.relu(x)
        x = self.max_pool1(x)  # 6x16x16
        x = self.conv_layer2(x)  # 12x16x16
        x = self.batch_norm2(x)
        x = F.relu(x)
        x = self.conv_layer3(x)  # 6x16x16
        x = self.batch_norm3(x)
        x = F.relu(x)
        x = self.max_pool2(x)  # 6x8x8
        x = torch.flatten(x, 1)  # 384 features
        x = self.dense1(x)
        x = F.relu(x)
        x = self.dropout(x)
        # No softmax here: nn.CrossEntropyLoss applies log-softmax to logits itself.
        x = self.dense2(x)
        return x

In [None]:
# actual model training

model = NeuralNetwork(dropout=0.5)
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()  # takes raw logits; default reduction is the batch mean

# Log the model graph once, using a real training batch as the example input.
sample_images, _ = next(iter(train_loader))
sample_images = sample_images.to(device)
w.add_graph(model, sample_images)

for epoch in range(epochs):
    # train
    model.train()
    total_train_loss = 0.0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        logits = model(images)
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()

    # test
    model.eval()
    correct = 0
    total = 0
    total_test_loss = 0.0
    # Evaluation needs no gradients; disabling autograd avoids building (and
    # holding memory for) a backward graph on every test batch.
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            logits = model(images)
            loss = loss_fn(logits, labels)
            total_test_loss += loss.item()
            predicted_classes = torch.argmax(logits, dim=1)
            correct += (predicted_classes == labels).sum().item()
            total += labels.size(0)

    # write to summary writer
    # Each accumulated loss is a sum of per-batch means, so divide by the number
    # of batches; the two loaders can have different batch counts, which makes
    # raw sums incomparable between the train and test curves.
    w.add_scalar("Train Loss", total_train_loss / len(train_loader), epoch)
    w.add_scalar("Test Loss", total_test_loss / len(test_loader), epoch)
    w.add_scalar("Test Accuracy", correct / total, epoch)
    # Gradients logged here are those left over from the epoch's last training batch.
    for name, param in model.named_parameters():
        w.add_histogram(f'Weights/{name}', param.detach(), epoch)
        if param.grad is not None:
            w.add_histogram(f'Gradients/{name}', param.grad.detach(), epoch)

print("training complete")

  import pynvml  # type: ignore[import]
  import pynvml  # type: ignore[import]
  import pynvml  # type: ignore[import]
  import pynvml  # type: ignore[import]
  import pynvml  # type: ignore[import]
  import pynvml  # type: ignore[import]
  import pynvml  # type: ignore[import]
  import pynvml  # type: ignore[import]
  import pynvml  # type: ignore[import]
  import pynvml  # type: ignore[import]
  import pynvml  # type: ignore[import]
  import pynvml  # type: ignore[import]
  import pynvml  # type: ignore[import]
  import pynvml  # type: ignore[import]
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9/multiprocessing/spawn.py", line 126, in _main
    self = 

KeyboardInterrupt: 

In [None]:
# Persist the trained weights next to the TensorBoard run directory
# (e.g. "./runs/mar_03_1422.pt"), then close the summary writer.
MODEL_PATH = run_path + ".pt"
weights = model.state_dict()
torch.save(weights, MODEL_PATH)
w.close()

## Takeaways

So in this module, the difficult part was setting up distributed data parallel (DDP) and then thinking about everything that a Python dev never thinks about: processes, threads, CPU cores, memory usage, etc. I think I still have a lot to learn in that regard, and the challenging part about this more general 'computer science stuff' is that I'll probably never know when I fully understand it. For instance, it's easy to know when you grasp a neural network architecture like a transformer. But there are so many things to consider when thinking about the operating system, computer hardware, library compatibility, and CPU processes that I'm not sure what kinds of resources will provide a full grasp. The PyTorch approach of imperatively putting tensors in "CPU land" or "GPU land" is deeply unfulfilling to me, and I think the pains I encountered while playing around with distributed training partially explain why.

Speaking of PyTorch's hardware abstractions and choices for DDP, I don't like how every process needs to be self-aware, with enough flexibility to do its own thing. After going through JAX, Keras, and PyTorch, I think I want to pursue the future modules with JAX. I still miss the days when vmap abstracted away the idea of a batch, turning it into just another dimension that the compiler takes care of, and after a quick search about pmap, I expect similar magic to occur if I ever have to scale these models. I also enjoyed how the math felt completely exposed and tied to JAX, and I didn't have to worry about the conventions and stylistic choices the API made for each of its modules.