In [1]:
# Import necessary packages.
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
# "ConcatDataset" and "Subset" are possibly useful when doing semi-supervised learning.
from torch.utils.data import ConcatDataset, DataLoader, Subset, Dataset
from torchvision.datasets import DatasetFolder

# This is for the progress bar.
from tqdm.auto import tqdm

In [2]:
# It is important to do data augmentation in training.
# However, not every augmentation is useful.
# Please think about what kind of augmentation is helpful for food recognition.
train_tfm1 = transforms.Compose([
    # Resize the image into a fixed shape (height = width = 128)
    transforms.Resize((128, 128)),
    # You may add some transforms here.
    # ToTensor() should be the last one of the transforms.
    transforms.ToTensor(),
])

train_tfm2 = transforms.Compose([
    # Resize the image into a fixed shape (height = width = 128)
    transforms.Resize((128, 128)),
    # You may add some transforms here.
    # ToTensor() should be the last one of the transforms.
    transforms.RandomHorizontalFlip(p=0.6),
    transforms.ColorJitter(brightness=0.5),
    transforms.RandomRotation(30),
    transforms.ToTensor(),
])



# We don't need augmentations in testing and validation.
# All we need here is to resize the PIL image and transform it into Tensor.
test_tfm = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])


In [3]:
# Batch size for training, validation, and testing.
# A greater batch size usually gives a more stable gradient.
# But the GPU memory is limited, so please adjust it carefully.
batch_size = 128

# Construct datasets.
# The argument "loader" tells how torchvision reads the data.
train_set1 = DatasetFolder("../input/ml2021spring-hw3/food-11/training/labeled", loader=lambda x: Image.open(x), extensions="jpg", transform=train_tfm1)
train_set2 = DatasetFolder("../input/ml2021spring-hw3/food-11/training/labeled", loader=lambda x: Image.open(x), extensions="jpg", transform=train_tfm2)
train_set = ConcatDataset([train_set1, train_set2])
valid_set1 = DatasetFolder("../input/ml2021spring-hw3/food-11/validation", loader=lambda x: Image.open(x), extensions="jpg", transform=train_tfm1)
valid_set2 = DatasetFolder("../input/ml2021spring-hw3/food-11/validation", loader=lambda x: Image.open(x), extensions="jpg", transform=train_tfm2)
valid_set = ConcatDataset([valid_set1, valid_set2])
unlabeled_set = DatasetFolder("../input/ml2021spring-hw3/food-11/training/unlabeled", loader=lambda x: Image.open(x), extensions="jpg", transform=train_tfm1)
test_set = DatasetFolder("../input/ml2021spring-hw3/food-11/testing", loader=lambda x: Image.open(x), extensions="jpg", transform=test_tfm)

# Construct data loaders.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, pin_memory=True)
valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=True, pin_memory=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

In [9]:
class Classifier(nn.Module):
    def __init__(self):
        super(Classifier, self).__init__()
        # The arguments for commonly used modules:
        # torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        # torch.nn.MaxPool2d(kernel_size, stride, padding)

        # input image size: [3, 128, 128]
        self.cnn_layers = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),

            nn.Conv2d(64, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),

            nn.Conv2d(128, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),

            nn.Conv2d(256, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),

            nn.Conv2d(512, 1024, 3, 1, 1),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),
            
            nn.Conv2d(1024, 2048, 3, 1, 1),
            nn.BatchNorm2d(2048),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),
        )
        self.fc_layers = nn.Sequential(
            nn.Linear(8192, 2048),
            nn.BatchNorm1d(2048),
            nn.ReLU(),
            nn.Dropout(0.55),
            nn.Linear(2048, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(256, 11)
        )

    def forward(self, x):
        # input (x): [batch_size, 3, 128, 128]
        # output: [batch_size, 11]

        # Extract features by convolutional layers.
        x = self.cnn_layers(x)

        # The extracted feature map must be flatten before going to fully-connected layers.
        x = x.flatten(1)

        # The features are transformed by fully-connected layers to obtain the final logits.
        x = self.fc_layers(x)
        return x

In [5]:
class PseudoDataset(Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __len__(self):
        return len(self.y)

    def __getitem__(self, id):
        return self.x[id][0], self.y[id]

def get_pseudo_labels(dataset, model, threshold=0.92):
    device = "cuda" if torch.cuda.is_available() else "cpu"

    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    model.eval()
    softmax = nn.Softmax(dim=-1)

    idx = []
    labels = []

    for i, batch in enumerate(data_loader):
        img, _ = batch
        with torch.no_grad():
            logits = model(img.to(device))
        probs = softmax(logits)

        for j, x in enumerate(probs):
            if torch.max(x) > threshold:
                idx.append(i * batch_size + j)
                labels.append(int(torch.argmax(x)))

    model.train()
    print ("\nNew data: {:5d}\n".format(len(idx)))
    dataset = PseudoDataset(Subset(dataset, idx), labels)
    return dataset



In [None]:
# "cuda" only when GPUs are available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize a model, and put it on the device specified.
model = Classifier().to(device)
model.device = device

# For the classification task, we use cross-entropy as the measurement of performance.
criterion = nn.CrossEntropyLoss()

# Initialize optimizer, you may fine-tune some hyperparameters such as learning rate on your own.
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-5)

# The number of training epochs.
n_epochs = 400

# Whether to do semi-supervised learning.
do_semi = True

# record information
best_acc = 0.0
train_loss_record = []
valid_loss_record = []
train_acc_record = []
valid_acc_record = []
model_path = "model.ckpt"

for epoch in range(n_epochs):
    # ---------- TODO ----------
    # In each epoch, relabel the unlabeled dataset for semi-supervised learning.
    # Then you can combine the labeled dataset and pseudo-labeled dataset for the training.
    if do_semi and best_acc > 0.6 and epoch % 4 == 0:
        # Obtain pseudo-labels for unlabeled data using trained model.
        pseudo_set = get_pseudo_labels(unlabeled_set, model)
        concat_dataset = ConcatDataset([train_set, pseudo_set])
        train_loader = DataLoader(concat_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, drop_last=True)


    # ---------- Training ----------
    # Make sure the model is in train mode before training.
    model.train()

    # These are used to record information in training.
    train_loss = []
    train_accs = []

    # Iterate the training set by batches.
    for batch in tqdm(train_loader):

        # A batch consists of image data and corresponding labels.
        imgs, labels = batch

        # Forward the data. (Make sure data and model are on the same device.)
        logits = model(imgs.to(device))

        # Calculate the cross-entropy loss.
        # We don't need to apply softmax before computing cross-entropy as it is done automatically.
        loss = criterion(logits, labels.to(device))

        # Gradients stored in the parameters in the previous step should be cleared out first.
        optimizer.zero_grad()

        # Compute the gradients for parameters.
        loss.backward()

        # Clip the gradient norms for stable training.
        grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)

        # Update the parameters with computed gradients.
        optimizer.step()

        # Compute the accuracy for current batch.
        acc = (logits.argmax(dim=-1) == labels.to(device)).float().mean()

        # Record the loss and accuracy.
        train_loss.append(loss.item())
        train_accs.append(acc)

    # The average loss and accuracy of the training set is the average of the recorded values.
    train_loss = sum(train_loss) / len(train_loss)
    train_acc = sum(train_accs) / len(train_accs)

    # Print the information.
    print(f"[ Train | {epoch + 1:03d}/{n_epochs:03d} ] loss = {train_loss:.5f}, acc = {train_acc:.5f}")

    # ---------- Validation ----------
    # Make sure the model is in eval mode so that some modules like dropout are disabled and work normally.
    model.eval()

    # These are used to record information in validation.
    valid_loss = []
    valid_accs = []

    # Iterate the validation set by batches.
    for batch in tqdm(valid_loader):

        # A batch consists of image data and corresponding labels.
        imgs, labels = batch

        # We don't need gradient in validation.
        # Using torch.no_grad() accelerates the forward process.
        with torch.no_grad():
          logits = model(imgs.to(device))

        # We can still compute the loss (but not the gradient).
        loss = criterion(logits, labels.to(device))

        # Compute the accuracy for current batch.
        acc = (logits.argmax(dim=-1) == labels.to(device)).float().mean()

        # Record the loss and accuracy.
        valid_loss.append(loss.item())
        valid_accs.append(acc)

    # The average loss and accuracy for entire validation set is the average of the recorded values.
    valid_loss = sum(valid_loss) / len(valid_loss)
    valid_acc = sum(valid_accs) / len(valid_accs)
    # Print the information.
    print(f"[ Valid | {epoch + 1:03d}/{n_epochs:03d} ] loss = {valid_loss:.5f}, acc = {valid_acc:.5f}")
    # ------Record every time information------
    if valid_acc > best_acc:
        best_acc = valid_acc
        torch.save(model.state_dict(), model_path)
    train_loss_record.append(train_loss)
    valid_loss_record.append(valid_loss)
    train_acc_record.append(train_acc)
    valid_acc_record.append(valid_acc)

  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 001/400 ] loss = 2.06295, acc = 0.27328


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 001/400 ] loss = 2.69077, acc = 0.14957


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 002/400 ] loss = 1.70885, acc = 0.42060


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 002/400 ] loss = 1.71634, acc = 0.41051


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 003/400 ] loss = 1.49649, acc = 0.50478


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 003/400 ] loss = 1.73954, acc = 0.41804


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 004/400 ] loss = 1.32306, acc = 0.56585


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 004/400 ] loss = 1.57802, acc = 0.44545


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 005/400 ] loss = 1.18333, acc = 0.61974


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 005/400 ] loss = 1.89326, acc = 0.37912


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 006/400 ] loss = 1.04108, acc = 0.68208


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 006/400 ] loss = 1.52982, acc = 0.49048


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 007/400 ] loss = 0.86867, acc = 0.74267


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 007/400 ] loss = 1.52319, acc = 0.49432


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 008/400 ] loss = 0.75863, acc = 0.78029


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 008/400 ] loss = 1.42706, acc = 0.52898


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 009/400 ] loss = 0.66270, acc = 0.81218


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 009/400 ] loss = 1.48547, acc = 0.50284


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 010/400 ] loss = 0.57742, acc = 0.83673


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 010/400 ] loss = 1.64654, acc = 0.47628


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 011/400 ] loss = 0.52729, acc = 0.85188


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 011/400 ] loss = 1.79699, acc = 0.47713


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 012/400 ] loss = 0.48142, acc = 0.86639


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 012/400 ] loss = 1.45682, acc = 0.52955


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 013/400 ] loss = 0.41383, acc = 0.88473


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 013/400 ] loss = 1.40729, acc = 0.54858


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 014/400 ] loss = 0.39310, acc = 0.89158


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 014/400 ] loss = 1.44135, acc = 0.56236


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 015/400 ] loss = 0.35887, acc = 0.89971


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 015/400 ] loss = 1.76300, acc = 0.49602


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 016/400 ] loss = 0.33650, acc = 0.90880


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 016/400 ] loss = 1.70985, acc = 0.48381


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 017/400 ] loss = 0.31871, acc = 0.90880


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 017/400 ] loss = 1.50963, acc = 0.53366


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 018/400 ] loss = 0.28615, acc = 0.91741


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 018/400 ] loss = 1.63012, acc = 0.52188


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 019/400 ] loss = 0.27352, acc = 0.92459


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 019/400 ] loss = 1.56749, acc = 0.52969


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 020/400 ] loss = 0.24732, acc = 0.93383


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 020/400 ] loss = 1.53986, acc = 0.52386


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 021/400 ] loss = 0.23490, acc = 0.93941


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 021/400 ] loss = 1.55330, acc = 0.54730


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 022/400 ] loss = 0.20241, acc = 0.94723


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 022/400 ] loss = 1.58728, acc = 0.52060


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 023/400 ] loss = 0.20790, acc = 0.94356


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 023/400 ] loss = 1.56272, acc = 0.54361


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 024/400 ] loss = 0.19120, acc = 0.95073


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 024/400 ] loss = 1.46202, acc = 0.56634


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 025/400 ] loss = 0.15457, acc = 0.95743


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 025/400 ] loss = 1.72722, acc = 0.52827


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 026/400 ] loss = 0.16954, acc = 0.95281


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 026/400 ] loss = 1.60970, acc = 0.55071


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 027/400 ] loss = 0.14386, acc = 0.96110


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 027/400 ] loss = 1.47465, acc = 0.56151


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 028/400 ] loss = 0.14683, acc = 0.95982


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 028/400 ] loss = 1.64394, acc = 0.53182


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 029/400 ] loss = 0.14031, acc = 0.96540


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 029/400 ] loss = 1.57344, acc = 0.55923


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 030/400 ] loss = 0.12300, acc = 0.96811


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 030/400 ] loss = 1.58246, acc = 0.54290


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 031/400 ] loss = 0.13777, acc = 0.96269


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 031/400 ] loss = 1.74070, acc = 0.52457


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 032/400 ] loss = 0.14834, acc = 0.95839


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 032/400 ] loss = 1.83372, acc = 0.51548


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 033/400 ] loss = 0.12815, acc = 0.96763


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 033/400 ] loss = 1.54533, acc = 0.56776


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 034/400 ] loss = 0.10339, acc = 0.97417


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 034/400 ] loss = 1.57002, acc = 0.56080


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 035/400 ] loss = 0.09713, acc = 0.97465


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 035/400 ] loss = 1.55831, acc = 0.56477


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 036/400 ] loss = 0.09125, acc = 0.97720


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 036/400 ] loss = 1.62372, acc = 0.54702


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 037/400 ] loss = 0.11049, acc = 0.97066


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 037/400 ] loss = 1.57893, acc = 0.54432


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 038/400 ] loss = 0.10355, acc = 0.97146


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 038/400 ] loss = 1.55098, acc = 0.57315


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 039/400 ] loss = 0.10175, acc = 0.97194


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 039/400 ] loss = 1.66985, acc = 0.55440


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 040/400 ] loss = 0.10361, acc = 0.96987


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 040/400 ] loss = 1.80188, acc = 0.52926


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 041/400 ] loss = 0.09601, acc = 0.97529


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 041/400 ] loss = 1.66365, acc = 0.53409


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 042/400 ] loss = 0.08442, acc = 0.97529


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 042/400 ] loss = 1.70910, acc = 0.55270


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 043/400 ] loss = 0.09762, acc = 0.97305


  0%|          | 0/11 [00:00<?, ?it/s]

[ Train | 044/400 ] loss = 0.07579, acc = 0.97991


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 044/400 ] loss = 1.57388, acc = 0.56278


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 045/400 ] loss = 0.07846, acc = 0.97895


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 045/400 ] loss = 1.71833, acc = 0.55313


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 046/400 ] loss = 0.08790, acc = 0.97337


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 046/400 ] loss = 1.67747, acc = 0.55185


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 047/400 ] loss = 0.06795, acc = 0.97911


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 047/400 ] loss = 1.62068, acc = 0.56193


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 048/400 ] loss = 0.06618, acc = 0.98469


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 048/400 ] loss = 1.59940, acc = 0.58068


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 049/400 ] loss = 0.05688, acc = 0.98709


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 049/400 ] loss = 1.54603, acc = 0.57770


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 051/400 ] loss = 0.06935, acc = 0.98246


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 051/400 ] loss = 1.96449, acc = 0.52884


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 052/400 ] loss = 0.07735, acc = 0.97688


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 052/400 ] loss = 1.75323, acc = 0.54531


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 053/400 ] loss = 0.06596, acc = 0.98023


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 053/400 ] loss = 1.68220, acc = 0.56094


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 054/400 ] loss = 0.06012, acc = 0.98182


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 054/400 ] loss = 1.71034, acc = 0.55923


  0%|          | 0/49 [00:00<?, ?it/s]

[ Train | 056/400 ] loss = 0.07510, acc = 0.97784


  0%|          | 0/11 [00:00<?, ?it/s]

[ Valid | 056/400 ] loss = 1.91511, acc = 0.55057


  0%|          | 0/49 [00:00<?, ?it/s]

In [None]:
import matplotlib.pyplot as plt

x = np.arange(len(train_acc_record))
plt.plot(x, train_acc_record, color="blue", label="Train")
plt.plot(x, valid_acc_record, color="red", label="Valid")
plt.legend(loc="upper right")
plt.show()

In [None]:
import matplotlib.pyplot as plt

x = np.arange(len(train_loss_record))
plt.plot(x, train_loss_record, color="blue", label="Train")
plt.plot(x, valid_loss_record, color="red", label="Valid")
plt.legend(loc="upper right") 
plt.show()