In [1]:
#Import necessary packages
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image

#ConcatDataset and Subset are possibly useful when doing semi-supervised
from torch.utils.data import ConcatDataset, DataLoader, Subset
from torchvision.datasets import DatasetFolder

#This is for the progress bar
from tqdm import tqdm

In [2]:
#It is important to do data augmentation in training.
#However, not every augmentation is useful
#Please think about what kind of augmentation is helpful for food recognition
train_tfm = transforms.Compose([
    transforms.Resize((256, 256)),
    #You may add some transforms here
    transforms.ColorJitter(0.1, 0.1, 0.1),
    transforms.RandomRotation(20),
    transforms.RandomHorizontalFlip(p = 0.5),
    transforms.ToTensor(),
    #transforms.Normalize([0.554, 0.450, 0.343], [0.231, 0.241, 0.241]),
    #transforms.RandomErasing()
    
])
#train_tfm = transforms.Compose([
#    transforms.RandomResizedCrop((128, 128)),
#    transforms.RandomChoice(
#        [transforms.AutoAugment(),
#        transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),
#        transforms.AutoAugment(transforms.AutoAugmentPolicy.SVHN)]
#    ),
#    transforms.RandomHorizontalFlip(p=0.5),
#    transforms.ColorJitter(brightness=0.5),
#    transforms.RandomRotation(5),
#    transforms.ToTensor(),
#])

# We don't need augmentations in testing and validation.
# All we need here is to resize the PIL image and transform it into Tensor.
test_tfm = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    #transforms.Normalize([0.554, 0.450, 0.343], [0.231, 0.241, 0.241])
])

In [3]:
# Batch size for training, validation, and testing.
# A greater batch size usually gives a more stable gradient.
# But the GPU memory is limited, so please adjust it carefully.
batch_size = 128

#Construct datasets
#The augment "loader" tells how torchvision reads the data
train_set = DatasetFolder("food-11/training/labeled", loader=lambda x: Image.open(x), extensions="jpg", transform=train_tfm)

valid_set = DatasetFolder("food-11/validation", loader=lambda x: Image.open(x), extensions="jpg", transform=test_tfm)
unlabeled_set = DatasetFolder("food-11/training/unlabeled", loader=lambda x: Image.open(x), extensions="jpg", transform=train_tfm)
test_set = DatasetFolder("food-11/testing", loader=lambda x: Image.open(x), extensions="jpg", transform=test_tfm)

# Construct data loaders.
#train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers = 8)
valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=True, drop_last = True, num_workers = 8)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

Debugger :  class -> idx 
{'00': 0, '01': 1, '02': 2, '03': 3, '04': 4, '05': 5, '06': 6, '07': 7, '08': 8, '09': 9, '10': 10}
Debugger :  class -> idx 
{'00': 0, '01': 1, '02': 2, '03': 3, '04': 4, '05': 5, '06': 6, '07': 7, '08': 8, '09': 9, '10': 10}
Debugger :  class -> idx 
{'00': 0}
Debugger :  class -> idx 
{'00': 0}


In [4]:
class pDataset(torch.utils.data.Dataset):
    def __init__(self, X, y):
        self.data = X
        #y = y.astype(np.int)
        self.label = y
    def __getitem__(self,idx): # if the index is idx, what will be the data?
        #print(self.label)
        return self.data, self.label
    def __len__(self): # What is the length of the dataset
        return 1

In [5]:
def get_pseudo_labels(dataset, model, threshold = 0.65):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    #print(device)
    
    #Construct a data loader
    data_loader = DataLoader(dataset, batch_size = batch_size, shuffle = True, drop_last = True, num_workers = 8)
    
    #make sure the model is in eval mode
    model.eval()

    #Define softmax function
    softmax = nn.Softmax(dim = -1)
    #linear = nn.Linear(1000, 11).to(device)

    flag = 0
    pseudo_set = None
    #Iterate over the dataset by batches
    for batch in tqdm(data_loader):
        img, _ = batch

        #Forward the data
        with torch.no_grad():
            logits = model(img.to(device))
        
        probs = softmax(logits)
        #print(probs.shape)

        for i in range(len(probs)):
            #if model give the picture high confidence
            m = torch.max(probs[i])
            if m > threshold and flag == 1:
                label = (probs[i] == m).nonzero().item()
                new_labeled_data = pDataset(img[i], label)
                pseudo_set = ConcatDataset([pseudo_set, new_labeled_data])
            elif m > threshold and flag == 0:
                label = (probs[i] == m).nonzero().item()
                pseudo_set = pDataset(img[i], label)
                flag = 1

    model.train()
    return pseudo_set

In [6]:
import sys
sys.setrecursionlimit(100000)

In [7]:
import torchvision

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = torchvision.models.resnet18(weights='ResNet18_Weights.DEFAULT')
#linear = nn.Linear(1000, 11).to(device)

model.to(device)

#For the classification task, we use cross-entropy as the measurement of performance
criterion = nn.CrossEntropyLoss()

#Initialize optimizer, you may fine-tune some hyperparameters such as learning rate on your own
#optimizer = torch.optim.SGD(model.parameters(), lr = 0.0003, weight_decay = 1e-5, momentum = 0.9)
optimizer = torch.optim.RAdam(model.parameters(), lr = 0.0003)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 20, gamma=0.1, last_epoch=-1)

#the number of training epochs
#n_epochs = 120
n_epochs = 40

#Whether to do semi-supervised learning
do_semi = True

train_proc = []
eval_proc = []
best_acc = 0

#pseudo_set = None


for epoch in range(n_epochs):
    pseudo_set = None
    if do_semi:
        pseudo_set = get_pseudo_labels(unlabeled_set, model)
        #print(len(pseudo_set))

    if not pseudo_set == None:
        #print(len(train_set))
        concat_dataset = ConcatDataset([train_set, pseudo_set])
        #print(len(concat_dataset))
        train_loader = DataLoader(concat_dataset, batch_size = batch_size, shuffle = True, drop_last = True, num_workers = 8)
    else:
        train_loader = DataLoader(train_set, batch_size = batch_size, shuffle = True, drop_last = True, num_workers = 8)
        
    model.train()


    train_loss = []
    train_accs = []
    for batch in tqdm(train_loader):
        imgs, labels = batch
        imgs, labels = imgs.to(device), labels.to(device)
        logits = model(imgs)
        
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm = 10)
        optimizer.step()
        acc = (logits.argmax(dim = -1) == labels).float().mean()
        train_loss.append(loss.item())
        train_accs.append(acc)
    # The average loss and accuracy of the training set is the average of the recorded values.
    scheduler.step()
    train_loss = sum(train_loss) / len(train_loss)
    train_acc = sum(train_accs) / len(train_accs)
    train_proc.append(train_loss)

    # Print the information.
    print(f"[ Train | {epoch :03d}/{n_epochs:03d} ] loss = {train_loss:.5f}, acc = {train_acc:.5f}")

    # ---------- Validation ----------
    # Make sure the model is in eval mode so that some modules like dropout are disabled and work normally.
    model.eval()

    # These are used to record information in validation.
    valid_loss = []
    valid_accs = []

    # Iterate the validation set by batches.
    for batch in tqdm(valid_loader):
        # A batch consists of image data and corresponding labels.
        imgs, labels = batch
        imgs, labels = imgs.to(device), labels.to(device)

        # We don't need gradient in validation.
        # Using torch.no_grad() accelerates the forward process.
        with torch.no_grad():
            logits = model(imgs)
        # We can still compute the loss (but not the gradient).
        loss = criterion(logits, labels)

        # Compute the accuracy for current batch.
        acc = (logits.argmax(dim= -1) == labels.to(device)).float().mean()

        # Record the loss and accuracy.
        valid_loss.append(loss.item())
        valid_accs.append(acc)
    
    # The average loss and accuracy for entire validation set is the average of the recorded values.
    valid_loss = sum(valid_loss) / len(valid_loss)
    valid_acc = sum(valid_accs) / len(valid_accs)
    eval_proc.append(valid_loss)

    # Print the information.
    print(f"[ Valid | {epoch :03d}/{n_epochs:03d} ] loss = {valid_loss:.5f}, acc = {valid_acc:.5f}")
    
    if valid_acc > best_acc:
        best_acc = valid_acc
        torch.save(model.state_dict(), f'./pretrain/food11_{epoch}.pt')

100%|██████████| 53/53 [00:09<00:00,  5.87it/s]
100%|██████████| 28/28 [00:05<00:00,  4.82it/s]


[ Train | 000/040 ] loss = 7.73025, acc = 0.05831


100%|██████████| 5/5 [00:01<00:00,  3.25it/s]


[ Valid | 000/040 ] loss = 6.60579, acc = 0.05781


100%|██████████| 53/53 [00:08<00:00,  6.30it/s]
100%|██████████| 25/25 [00:05<00:00,  4.73it/s]


[ Train | 001/040 ] loss = 4.18513, acc = 0.32375


100%|██████████| 5/5 [00:01<00:00,  3.33it/s]


[ Valid | 001/040 ] loss = 2.69315, acc = 0.52031


100%|██████████| 53/53 [00:08<00:00,  6.50it/s]
100%|██████████| 42/42 [00:07<00:00,  5.50it/s]


[ Train | 002/040 ] loss = 0.92749, acc = 0.82366


100%|██████████| 5/5 [00:01<00:00,  3.48it/s]


[ Valid | 002/040 ] loss = 1.28295, acc = 0.73594


100%|██████████| 53/53 [00:08<00:00,  6.26it/s]
100%|██████████| 57/57 [00:09<00:00,  5.93it/s]


[ Train | 003/040 ] loss = 0.35643, acc = 0.91817


100%|██████████| 5/5 [00:01<00:00,  2.98it/s]


[ Valid | 003/040 ] loss = 0.91401, acc = 0.79375


100%|██████████| 53/53 [00:08<00:00,  5.99it/s]
100%|██████████| 65/65 [00:11<00:00,  5.78it/s]


[ Train | 004/040 ] loss = 0.22560, acc = 0.94519


100%|██████████| 5/5 [00:01<00:00,  3.39it/s]


[ Valid | 004/040 ] loss = 0.78562, acc = 0.80937


100%|██████████| 53/53 [00:08<00:00,  6.30it/s]
100%|██████████| 69/69 [00:11<00:00,  5.88it/s]


[ Train | 005/040 ] loss = 0.16279, acc = 0.95312


100%|██████████| 5/5 [00:01<00:00,  2.99it/s]


[ Valid | 005/040 ] loss = 0.67618, acc = 0.82969


100%|██████████| 53/53 [00:08<00:00,  5.99it/s]
100%|██████████| 71/71 [00:12<00:00,  5.90it/s]


[ Train | 006/040 ] loss = 0.14803, acc = 0.95456


100%|██████████| 5/5 [00:01<00:00,  2.83it/s]


[ Valid | 006/040 ] loss = 0.71030, acc = 0.82344


100%|██████████| 53/53 [00:08<00:00,  6.15it/s]
100%|██████████| 72/72 [00:12<00:00,  5.99it/s]


[ Train | 007/040 ] loss = 0.14574, acc = 0.95388


100%|██████████| 5/5 [00:01<00:00,  3.37it/s]


[ Valid | 007/040 ] loss = 0.73073, acc = 0.81875


100%|██████████| 53/53 [00:08<00:00,  6.04it/s]
100%|██████████| 72/72 [00:12<00:00,  5.91it/s]


[ Train | 008/040 ] loss = 0.14140, acc = 0.95519


100%|██████████| 5/5 [00:01<00:00,  3.07it/s]


[ Valid | 008/040 ] loss = 0.86490, acc = 0.81094


100%|██████████| 53/53 [00:08<00:00,  5.92it/s]
100%|██████████| 73/73 [00:12<00:00,  5.66it/s]


[ Train | 009/040 ] loss = 0.13507, acc = 0.95484


100%|██████████| 5/5 [00:01<00:00,  2.94it/s]


[ Valid | 009/040 ] loss = 0.80415, acc = 0.81875


100%|██████████| 53/53 [00:08<00:00,  6.40it/s]
100%|██████████| 73/73 [00:12<00:00,  5.91it/s]


[ Train | 010/040 ] loss = 0.14844, acc = 0.95238


100%|██████████| 5/5 [00:01<00:00,  3.09it/s]


[ Valid | 010/040 ] loss = 1.05613, acc = 0.79375


100%|██████████| 53/53 [00:08<00:00,  5.97it/s]
100%|██████████| 73/73 [00:12<00:00,  5.96it/s]


[ Train | 011/040 ] loss = 0.17198, acc = 0.94167


100%|██████████| 5/5 [00:01<00:00,  3.24it/s]


[ Valid | 011/040 ] loss = 1.35156, acc = 0.74687


100%|██████████| 53/53 [00:08<00:00,  6.27it/s]
100%|██████████| 73/73 [00:12<00:00,  5.95it/s]


[ Train | 012/040 ] loss = 0.22774, acc = 0.92551


100%|██████████| 5/5 [00:01<00:00,  3.23it/s]


[ Valid | 012/040 ] loss = 1.05965, acc = 0.77500


100%|██████████| 53/53 [00:08<00:00,  6.09it/s]
100%|██████████| 73/73 [00:12<00:00,  5.98it/s]


[ Train | 013/040 ] loss = 0.15368, acc = 0.94735


100%|██████████| 5/5 [00:01<00:00,  2.75it/s]


[ Valid | 013/040 ] loss = 0.90754, acc = 0.78125


100%|██████████| 53/53 [00:08<00:00,  6.34it/s]
100%|██████████| 72/72 [00:12<00:00,  5.91it/s]


[ Train | 014/040 ] loss = 0.14320, acc = 0.95182


100%|██████████| 5/5 [00:01<00:00,  3.03it/s]


[ Valid | 014/040 ] loss = 1.16884, acc = 0.76875


100%|██████████| 53/53 [00:08<00:00,  6.25it/s]
100%|██████████| 73/73 [00:12<00:00,  6.07it/s]


[ Train | 015/040 ] loss = 0.14705, acc = 0.95141


100%|██████████| 5/5 [00:01<00:00,  3.24it/s]


[ Valid | 015/040 ] loss = 1.02461, acc = 0.76875


100%|██████████| 53/53 [00:08<00:00,  6.24it/s]
100%|██████████| 73/73 [00:12<00:00,  5.96it/s]


[ Train | 016/040 ] loss = 0.12714, acc = 0.95441


100%|██████████| 5/5 [00:01<00:00,  3.24it/s]


[ Valid | 016/040 ] loss = 1.10332, acc = 0.77969


100%|██████████| 53/53 [00:08<00:00,  6.09it/s]
100%|██████████| 73/73 [00:12<00:00,  5.96it/s]


[ Train | 017/040 ] loss = 0.12398, acc = 0.95666


100%|██████████| 5/5 [00:01<00:00,  2.97it/s]


[ Valid | 017/040 ] loss = 0.93491, acc = 0.81719


100%|██████████| 53/53 [00:08<00:00,  6.03it/s]
100%|██████████| 74/74 [00:12<00:00,  5.98it/s]


[ Train | 018/040 ] loss = 0.11526, acc = 0.96136


100%|██████████| 5/5 [00:01<00:00,  3.28it/s]


[ Valid | 018/040 ] loss = 0.98553, acc = 0.78750


100%|██████████| 53/53 [00:08<00:00,  6.07it/s]
100%|██████████| 74/74 [00:12<00:00,  6.02it/s]


[ Train | 019/040 ] loss = 0.14351, acc = 0.95144


100%|██████████| 5/5 [00:01<00:00,  3.38it/s]


[ Valid | 019/040 ] loss = 0.92795, acc = 0.80313


100%|██████████| 53/53 [00:08<00:00,  6.13it/s]
100%|██████████| 74/74 [00:12<00:00,  6.01it/s]


[ Train | 020/040 ] loss = 0.06851, acc = 0.97614


100%|██████████| 5/5 [00:01<00:00,  2.94it/s]


[ Valid | 020/040 ] loss = 0.87374, acc = 0.81250


100%|██████████| 53/53 [00:08<00:00,  5.91it/s]
100%|██████████| 75/75 [00:12<00:00,  6.02it/s]


[ Train | 021/040 ] loss = 0.04393, acc = 0.98740


100%|██████████| 5/5 [00:01<00:00,  3.27it/s]


[ Valid | 021/040 ] loss = 0.82935, acc = 0.81875


100%|██████████| 53/53 [00:08<00:00,  6.24it/s]
100%|██████████| 75/75 [00:12<00:00,  5.97it/s]


[ Train | 022/040 ] loss = 0.03703, acc = 0.98979


100%|██████████| 5/5 [00:01<00:00,  3.39it/s]


[ Valid | 022/040 ] loss = 0.81263, acc = 0.82812


100%|██████████| 53/53 [00:08<00:00,  6.08it/s]
100%|██████████| 75/75 [00:12<00:00,  5.84it/s]


[ Train | 023/040 ] loss = 0.02984, acc = 0.99406


100%|██████████| 5/5 [00:01<00:00,  2.80it/s]


[ Valid | 023/040 ] loss = 0.80792, acc = 0.82969


100%|██████████| 53/53 [00:09<00:00,  5.63it/s]
100%|██████████| 75/75 [00:13<00:00,  5.72it/s]


[ Train | 024/040 ] loss = 0.02755, acc = 0.99385


100%|██████████| 5/5 [00:01<00:00,  3.00it/s]


[ Valid | 024/040 ] loss = 0.78946, acc = 0.83281


100%|██████████| 53/53 [00:08<00:00,  5.97it/s]
100%|██████████| 75/75 [00:12<00:00,  5.88it/s]


[ Train | 025/040 ] loss = 0.02846, acc = 0.99260


100%|██████████| 5/5 [00:01<00:00,  3.26it/s]


[ Valid | 025/040 ] loss = 0.81176, acc = 0.82500


100%|██████████| 53/53 [00:08<00:00,  6.14it/s]
100%|██████████| 75/75 [00:12<00:00,  5.94it/s]


[ Train | 026/040 ] loss = 0.02433, acc = 0.99469


100%|██████████| 5/5 [00:01<00:00,  3.01it/s]


[ Valid | 026/040 ] loss = 0.80762, acc = 0.82656


100%|██████████| 53/53 [00:08<00:00,  6.34it/s]
100%|██████████| 75/75 [00:12<00:00,  5.90it/s]


[ Train | 027/040 ] loss = 0.02436, acc = 0.99344


100%|██████████| 5/5 [00:01<00:00,  3.44it/s]


[ Valid | 027/040 ] loss = 0.76670, acc = 0.83281


100%|██████████| 53/53 [00:08<00:00,  6.33it/s]
100%|██████████| 75/75 [00:12<00:00,  5.91it/s]


[ Train | 028/040 ] loss = 0.02575, acc = 0.99354


100%|██████████| 5/5 [00:01<00:00,  3.18it/s]


[ Valid | 028/040 ] loss = 0.80168, acc = 0.83594


100%|██████████| 53/53 [00:08<00:00,  6.09it/s]
100%|██████████| 75/75 [00:12<00:00,  5.86it/s]


[ Train | 029/040 ] loss = 0.02229, acc = 0.99438


100%|██████████| 5/5 [00:01<00:00,  3.07it/s]


[ Valid | 029/040 ] loss = 0.81897, acc = 0.83281


100%|██████████| 53/53 [00:08<00:00,  6.29it/s]
100%|██████████| 76/76 [00:12<00:00,  5.97it/s]


[ Train | 030/040 ] loss = 0.02242, acc = 0.99383


100%|██████████| 5/5 [00:01<00:00,  3.34it/s]


[ Valid | 030/040 ] loss = 0.81449, acc = 0.82969


100%|██████████| 53/53 [00:08<00:00,  6.18it/s]
100%|██████████| 76/76 [00:12<00:00,  6.00it/s]


[ Train | 031/040 ] loss = 0.02213, acc = 0.99414


100%|██████████| 5/5 [00:01<00:00,  3.00it/s]


[ Valid | 031/040 ] loss = 0.86015, acc = 0.82969


100%|██████████| 53/53 [00:08<00:00,  6.26it/s]
100%|██████████| 76/76 [00:12<00:00,  6.05it/s]


[ Train | 032/040 ] loss = 0.02110, acc = 0.99414


100%|██████████| 5/5 [00:01<00:00,  3.29it/s]


[ Valid | 032/040 ] loss = 0.79593, acc = 0.84062


100%|██████████| 53/53 [00:08<00:00,  6.14it/s]
100%|██████████| 76/76 [00:12<00:00,  6.04it/s]


[ Train | 033/040 ] loss = 0.02147, acc = 0.99424


100%|██████████| 5/5 [00:01<00:00,  2.86it/s]


[ Valid | 033/040 ] loss = 0.76607, acc = 0.84219


100%|██████████| 53/53 [00:08<00:00,  6.19it/s]
100%|██████████| 76/76 [00:12<00:00,  5.92it/s]


[ Train | 034/040 ] loss = 0.01876, acc = 0.99579


100%|██████████| 5/5 [00:01<00:00,  3.26it/s]


[ Valid | 034/040 ] loss = 0.75671, acc = 0.84375


100%|██████████| 53/53 [00:08<00:00,  6.00it/s]
100%|██████████| 76/76 [00:12<00:00,  6.01it/s]


[ Train | 035/040 ] loss = 0.01923, acc = 0.99548


100%|██████████| 5/5 [00:01<00:00,  3.01it/s]


[ Valid | 035/040 ] loss = 0.80377, acc = 0.83438


100%|██████████| 53/53 [00:08<00:00,  6.01it/s]
100%|██████████| 76/76 [00:12<00:00,  5.95it/s]


[ Train | 036/040 ] loss = 0.01728, acc = 0.99650


100%|██████████| 5/5 [00:01<00:00,  3.33it/s]


[ Valid | 036/040 ] loss = 0.80969, acc = 0.83438


100%|██████████| 53/53 [00:08<00:00,  6.21it/s]
100%|██████████| 76/76 [00:12<00:00,  6.00it/s]


[ Train | 037/040 ] loss = 0.01763, acc = 0.99558


100%|██████████| 5/5 [00:01<00:00,  3.36it/s]


[ Valid | 037/040 ] loss = 0.81351, acc = 0.83438


100%|██████████| 53/53 [00:08<00:00,  6.23it/s]
100%|██████████| 76/76 [00:12<00:00,  5.95it/s]


[ Train | 038/040 ] loss = 0.01898, acc = 0.99548


100%|██████████| 5/5 [00:01<00:00,  3.11it/s]


[ Valid | 038/040 ] loss = 0.81898, acc = 0.83594


100%|██████████| 53/53 [00:08<00:00,  6.16it/s]
100%|██████████| 76/76 [00:12<00:00,  5.99it/s]


[ Train | 039/040 ] loss = 0.01695, acc = 0.99620


100%|██████████| 5/5 [00:01<00:00,  3.27it/s]

[ Valid | 039/040 ] loss = 0.84027, acc = 0.83125





In [8]:
import matplotlib.pyplot as plt


epochs = np.arange(n_epochs)
plt.figure(0)
plt.plot(epochs, np.array(train_proc), color = 'b', label = 'training loss')
plt.plot(epochs, np.array(eval_proc), color = 'r', label = 'Evaluation loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.savefig(f'loss_surface.png')
plt.clf()
plt.show()

<Figure size 432x288 with 0 Axes>

In [11]:
model = torchvision.models.resnet18()
model.to(device)
checkpoint = torch.load('pretrain/food11_28.pt')
model.load_state_dict(checkpoint)


model.eval()

predictions = []

for batch in tqdm(test_loader):
    imgs, labels = batch
    softmax = nn.Softmax(dim = -1)

    with torch.no_grad():
        logits = model(imgs.to(device))
        #print(softmax(logits))

    predictions.extend(logits.argmax(dim = 1).cpu().numpy().tolist())

100%|██████████| 27/27 [00:16<00:00,  1.63it/s]


In [12]:
# Save predictions into the file.
with open("predict.csv", "w") as f:

    # The first row must be "Id, Category"
    f.write("Id,Category\n")

    # For the rest of the rows, each image id corresponds to a predicted class.
    for i, pred in  enumerate(predictions):
         f.write(f"{i},{pred}\n")