In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from convnext import ConvNeXt
from fastprogress.fastprogress import master_bar, progress_bar
from torchvision.transforms import RandomResizedCrop, CenterCrop, Resize, RandomHorizontalFlip, Compose
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

  def convnext_tiny(pretrained=False,in_22k=False, **kwargs):
  def convnext_small(pretrained=False,in_22k=False, **kwargs):
  def convnext_base(pretrained=False, in_22k=False, **kwargs):
  def convnext_large(pretrained=False, in_22k=False, **kwargs):
  def convnext_xlarge(pretrained=False, in_22k=False, **kwargs):


In [2]:
# Load the ImageNet-1k train and validation splits via `datasets.load_dataset`.
train_dataset = load_dataset("imagenet-1k",split="train",trust_remote_code=True)
valid_dataset = load_dataset("imagenet-1k",split="validation",trust_remote_code=True)
# Keep only samples whose image reports mode "RGB", then switch the dataset's
# output format to "torch" so downstream code receives tensors.
# NOTE(review): the `.mode` check presumably requires decoding every image, so
# this filter is expensive on the full training split — confirm.
# NOTE(review): `datasets` presumably caches filter results keyed on the
# function; editing the lambda may trigger a full re-filter — confirm.
rgb_train = train_dataset.filter(lambda s: s["image"].mode == "RGB").with_format("torch")
rgb_valid = valid_dataset.filter(lambda s: s["image"].mode == "RGB").with_format("torch")

Loading dataset shards:   0%|          | 0/257 [00:00<?, ?it/s]

In [3]:
def custom_collate_fn(batch):
    """Collate training samples into an augmented image batch and a label tensor.

    Every element of ``batch`` is indexed with ``'image'`` and ``'label'``.
    Each image is permuted with ``(2, 0, 1)``, cast to float and divided by
    255 (presumably converting channels-last uint8 data to channels-first
    values in [0, 1] — TODO confirm against the dataset format). It is then
    augmented with a random resized crop to 256 and a random horizontal flip.

    Args:
        batch: Sequence of samples supporting ``sample['image']`` and
            ``sample['label']``.

    Returns:
        tuple: ``(images, labels)`` where ``images`` is the stack of the
        augmented images and ``labels`` is an int64 tensor of the labels.
    """
    augment = Compose([
        RandomResizedCrop(256, scale=(0.5, 1.0), ratio=(0.8, 1.25)),
        RandomHorizontalFlip(),
    ])

    augmented = []
    for sample in batch:
        # Reorder axes and rescale before applying the random augmentation.
        scaled = sample['image'].permute(2, 0, 1).to(torch.float) / 255
        augmented.append(augment(scaled))
    images = torch.stack(augmented)

    labels = torch.tensor([sample['label'] for sample in batch], dtype=torch.int64)
    return images, labels

# Deterministic evaluation preprocessing: an antialiased Resize(256) followed
# by a 256 center crop (no random augmentation).
val_transforms = Compose([Resize(256, antialias=True), CenterCrop(256)])

In [4]:
# Build the ConvNeXt classifier for 3-channel inputs and 1000 output classes.
# NOTE(review): these depths/dims presumably correspond to the "tiny" variant —
# confirm against convnext.py.
classifier = ConvNeXt(
    in_chans=3,
    num_classes=1000,
    depths=[3, 3, 9, 3],
    dims=[96, 192, 384, 768],
    drop_path_rate=0.,
    layer_scale_init_value=1e-6,
    head_init_scale=1.,
)
# Move the model's parameters to the GPU.
classifier = classifier.cuda()

In [5]:
# Training schedule.
epochs = 50
batch_size = 128
# Number of micro-batches accumulated per optimizer step, chosen so that
# batch_size * accumulation_steps == 1024 samples per step.
accumulation_steps = 1024 // batch_size

# Objective, optimizer and learning-rate schedule.
loss_function = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = AdamW(
    classifier.parameters(),
    lr=0.001,
    weight_decay=1e-3,
)
scheduler = CosineAnnealingWarmRestarts(
    optimizer,
    T_0=10,
    T_mult=2,
    eta_min=1e-6,
)

# Per-epoch history, appended to by the training loop.
train_loss, test_loss, val_accuracy = [], [], []

In [None]:
# Build the loaders once. Nothing about them changes between epochs, and a
# DataLoader created with shuffle=True reshuffles each time it is iterated.
dataloader_train = DataLoader(rgb_train, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=12, collate_fn=custom_collate_fn)
# Validation samples are resized/cropped only after loading, so they are fed
# one image at a time (batch_size=1, default collation).
dataloader_valid = DataLoader(rgb_valid, batch_size=1, shuffle=False, drop_last=False)
num_train_batches = len(dataloader_train)

mb = master_bar(range(1, epochs + 1))

for epoch in mb:

    # ---- Training ----
    classifier.train()
    optimizer.zero_grad()
    running_loss = 0.0
    for i_batch, (x, target) in enumerate(progress_bar(dataloader_train, parent=mb)):
        x = x.cuda()
        outputs = classifier(x)
        loss = loss_function(outputs, target.cuda())
        # Scale before backward so the accumulated gradient is the mean over
        # the accumulation window rather than the sum of its micro-batches.
        (loss / accumulation_steps).backward()

        # Read the scalar once (each .item() forces a device sync) and log the
        # unscaled value so the reported loss is comparable across settings.
        loss_value = loss.item()
        print(loss_value, end='\r')
        running_loss += loss_value

        window_complete = (i_batch + 1) % accumulation_steps == 0
        last_batch = (i_batch + 1) == num_train_batches
        # Also step on the final batch so gradients from a trailing partial
        # window are applied instead of being wiped by the next zero_grad().
        if window_complete or last_batch:
            optimizer.step()
            optimizer.zero_grad()

    train_epoch_loss = running_loss / num_train_batches
    train_loss.append(train_epoch_loss)

    # ---- Validation ----
    classifier.eval()
    running_loss_val = 0.0
    correct_predictions = 0
    total_predictions = 0
    with torch.no_grad():
        for inputs in dataloader_valid:
            # Same axis reorder and 1/255 scaling as the training collate
            # function, applied here to a batched tensor.
            x = inputs['image'].permute(0,3,1,2).to(torch.float)/255
            x = val_transforms(x).cuda()
            targets = inputs['label'].to(torch.int64).cuda()
            outputs = classifier(x)
            loss = loss_function(outputs, targets)
            running_loss_val += loss.item()
            predicted = outputs.argmax(dim=1)
            correct_predictions += (predicted == targets).sum().item()
            total_predictions += targets.size(0)
    val_epoch_loss = running_loss_val / len(dataloader_valid)
    test_loss.append(val_epoch_loss)
    epoch_accuracy = correct_predictions / total_predictions
    print(f'Epoch: {epoch}, Accuracy: {epoch_accuracy:.4f}')
    val_accuracy.append(epoch_accuracy)

    # ---- Live loss curves on the master bar ----
    graphs = [
        [range(1, epoch + 1), train_loss],  # Training Loss
        [range(1, epoch + 1), test_loss],  # Validation Loss
    ]
    x_bounds = [1, epoch]
    y_bounds = [min(min(train_loss + test_loss) - 0.05, 0), max(max(train_loss + test_loss) + 0.05, 1)]
    mb.update_graph(graphs, x_bounds, y_bounds)

    # The learning-rate schedule advances once per epoch.
    scheduler.step()

6.5663051605224615

In [None]:
# Bundle everything needed to resume training or inspect the run:
# model/optimizer/scheduler state plus the per-epoch history lists.
# NOTE(review): the filename is tagged with `epoch-1` while the stored 'epoch'
# field is `epoch` — confirm which of the two is intended.
checkpoint = dict(
    model_state_dict=classifier.state_dict(),
    optimizer_state_dict=optimizer.state_dict(),
    scheduler_state_dict=scheduler.state_dict(),
    epoch=epoch,
    train_loss=train_loss,
    test_loss=test_loss,
    val_accuracy=val_accuracy,
)
checkpoint_filename = f'convnext_imagenet1k_e{epoch-1}.pt'
torch.save(checkpoint, checkpoint_filename)