In [None]:
import os
import pickle

import torch
import torch.nn.functional as F
import torch.optim as optim
from efficientnet_pytorch import EfficientNet
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR100

# Run configuration. `dataset_dir` can be overridden through the CIFAR_DIR
# environment variable; the default is the original location.
dataset_dir = os.environ.get('CIFAR_DIR', '/D/datasets/CIFAR')
batch_size = 32
NUM_WORKERS = 4
num_epochs = 25

# Seed torch's global RNG (all devices) so head initialisation, DataLoader
# shuffling and the random augmentations are repeatable on Restart & Run All.
# (cuDNN kernels may still introduce small non-determinism on GPU.)
SEED = 0
torch.manual_seed(SEED)

# The backbone below is loaded with ImageNet-pretrained weights, so inputs are
# normalised with the ImageNet channel statistics those weights expect.
# (The often-copied (0.4914, 0.4822, 0.4465) / (0.2023, 0.1994, 0.2010) values
# are CIFAR-10 statistics and match neither CIFAR-100 nor the pretrained weights.)
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
IMAGE_SIZE = 224  # resolution the 32x32 CIFAR images are upsampled to for EfficientNet-B0

# Augment at CIFAR's native 32x32 resolution and only then upsample: a 4-pixel
# pad-and-crop is the usual CIFAR jitter, whereas 4 pixels of padding applied
# after upsampling to 224 shifts the image by well under one original pixel.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])
test_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Create the dataset directory (including any missing parents) and let
# torchvision fetch CIFAR-100 only when needed: with download=True the archive
# is not downloaded again if the dataset is already present under `root`.
# This also covers an existing-but-empty directory, which previously raised
# because download was disabled whenever the directory merely existed.
os.makedirs(dataset_dir, exist_ok=True)
train_data = CIFAR100(root=dataset_dir, train=True, download=True, transform=train_transform)
test_data = CIFAR100(root=dataset_dir, train=False, download=True, transform=test_transform)

# Both loaders share batch size and worker count; they differ only in whether
# samples are shuffled and whether a final partial batch is discarded.
loader_kwargs = dict(batch_size=batch_size, num_workers=NUM_WORKERS)
train_loader = DataLoader(train_data, shuffle=True, drop_last=True, **loader_kwargs)
test_loader = DataLoader(test_data, shuffle=False, drop_last=False, **loader_kwargs)

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Start from pretrained EfficientNet-B0 weights and replace the final linear
# layer with a freshly initialised one producing 100 outputs (CIFAR-100 classes).
efficientnet = EfficientNet.from_pretrained('efficientnet-b0')
head_in_features = efficientnet._fc.in_features
efficientnet._fc = torch.nn.Linear(head_in_features, 100)
# nn.Module.to moves the parameters in place and returns the same module.
efficientnet = efficientnet.to(device)


In [None]:
# Adam with a 1e-3 base learning rate; StepLR multiplies the rate by 0.1 every
# 10 scheduler steps (the training loop steps it once per epoch).
optimizer = optim.Adam(efficientnet.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.StepLR(optimizer, gamma=0.1, step_size=10)

def train(epoch):
    """Run one training epoch of the module-level `efficientnet` over `train_loader`.

    Each mini-batch is moved to `device`, scored with cross-entropy and used for
    one `optimizer` step. Every 100th batch the current batch loss is printed.

    Args:
        epoch: epoch number; only used in the progress message.
    """
    efficientnet.train()
    n_batches = len(train_loader)
    n_samples = len(train_loader.dataset)

    for step, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        batch_loss = F.cross_entropy(efficientnet(images), labels)
        batch_loss.backward()
        optimizer.step()

        if step % 100 != 0:
            continue
        seen = step * len(images)
        pct = 100. * step / n_batches
        print(f'Train Epoch: {epoch} [{seen}/{n_samples} ({pct:.0f}%)]\tLoss: {batch_loss.item():.6f}')

def test():
    efficientnet.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = efficientnet(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n')

# Train for `num_epochs` epochs numbered from 1, evaluating on the test split
# after each epoch and advancing the StepLR schedule once per epoch.
epoch_numbers = range(1, num_epochs + 1)
for epoch in epoch_numbers:
    train(epoch)        # one optimisation pass over train_loader
    test()              # evaluate on test_loader
    scheduler.step()    # StepLR counts epochs, so step after every epoch

In [None]:
# Final evaluation of the trained classifier on the held-out test split.
# `test_model` was never defined in this notebook (NameError on Restart & Run All);
# the notebook's own `test()` evaluates `efficientnet` on `test_loader`.
test()


In [None]:
# Persist the whole fine-tuned module (pickled by torch.save) *before* the next
# cell replaces its classification head with Identity. Reloading a whole pickled
# module needs the efficientnet_pytorch classes importable, and unpickling can run
# arbitrary code, so only load this file from a trusted location.
MODEL_PATH = '/D/models/cifar_efficientnet.pkl'
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
torch.save(efficientnet, MODEL_PATH)

In [None]:
# Turn the fine-tuned network into a feature extractor: with the classification
# head replaced by Identity, the forward pass returns the features that used to
# feed `_fc` (its `in_features` width) rather than the 100 class scores.
# NOTE: this mutates `efficientnet` in place - run the torch.save cell first.
efficientnet._fc = torch.nn.Identity()
efficientnet.eval()  # make sure BatchNorm/Dropout are in inference mode

# `train_loader` shuffles, drops the last partial batch and applies random
# flip/crop augmentation, so features extracted through it would describe
# augmented images in random order. Use an ordered, un-augmented view of the
# same training split (already downloaded above) instead.
train_eval_loader = DataLoader(
    CIFAR100(root=dataset_dir, train=True, transform=test_transform),
    batch_size=batch_size,
    shuffle=False,
    num_workers=NUM_WORKERS,
    drop_last=False,
)

# NOTE(review): `get_logits_dataloader` is not defined or imported anywhere in this
# notebook's visible cells; it must come from a helper cell/module - confirm source.
# The misspelled `..._trian_...` name is kept because the pickling cell refers to it.
cifar_logits_trian_dataloader = get_logits_dataloader(efficientnet, train_eval_loader, batch_size=32, whiten=False)
cifar_logits_val_dataloader = get_logits_dataloader(efficientnet, test_loader, batch_size=32, whiten=False)

In [None]:
# Pickle the extracted-feature dataloaders next to the raw CIFAR data
# (<dataset_dir>/logits_dataloaders/), creating the folder if needed.
# NOTE: pickle.load executes arbitrary code - only load these files from a
# trusted location.
logits_dir = os.path.join(dataset_dir, 'logits_dataloaders')
os.makedirs(logits_dir, exist_ok=True)
for file_name, logits_loader in (
    ('logits_train_dataloader.pkl', cifar_logits_trian_dataloader),
    ('logits_test_dataloader.pkl', cifar_logits_val_dataloader),
):
    with open(os.path.join(logits_dir, file_name), 'wb') as f:
        pickle.dump(logits_loader, f)