# PyTorch Custom Datasets 

### Import PyTorch and set up device-agnostic code.

In [None]:
import torch
from torch import nn

print(torch.__version__)

In [None]:
device = "cuda" if torch.cuda.is_available() else 'cpu'
device

# Data

The dataset is a subset of the Food101 dataset.


In [None]:
import requests
import zipfile
from pathlib import Path

# Location of the dataset on disk and of the downloaded archive.
data_path = Path("data/")
image_path = data_path / "pizza_steak_sushi"
zip_path = data_path / "pizza_steak_sushi.zip"

# Only download when the image folder is missing or empty; an empty folder can be
# left behind by an earlier run that failed after creating the directory.
if image_path.is_dir() and any(image_path.iterdir()):
    print(f"{image_path} directory already exists...skipping download")
else:
    print(f"{image_path} does not exist, creating one...")
    image_path.mkdir(parents=True, exist_ok=True)

    # Download the archive and fail loudly on an HTTP error instead of
    # writing an error page to disk and trying to unzip it.
    print('downloading pizza, steak, and sushi data')
    request = requests.get("https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip")
    request.raise_for_status()
    with open(zip_path, "wb") as f:
        f.write(request.content)

    # Extract the archive into the image folder.
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        print("unzipping...")
        zip_ref.extractall(image_path)

# Exploring data

In [None]:
import os
def walk_through_dir(dir_path):
    """Print how many sub-directories and files each directory under ``dir_path`` contains.

    Args:
        dir_path: Root directory to traverse with ``os.walk`` (the root itself is included).
    """
    for current_dir, subdirs, files in os.walk(dir_path):
        n_dirs, n_files = len(subdirs), len(files)
        print(f"There are {n_dirs} directories and {n_files} images in '{current_dir}'.")

In [None]:
walk_through_dir(image_path)

In [None]:
train_dir = image_path / "train"
test_dir = image_path / "test"

train_dir, test_dir

# Visualizing images

1. Get all image paths
2. Pick a random image path with `random.choice()`
3. Get the image class name using `pathlib.Path.parent.stem`
4. Open the image with Pillow
5. Show the image metadata

In [None]:
import random
from PIL import Image

# random.seed(42)

image_path_list = list(image_path.glob("*/*/*.jpg"))
image_path_list

random_image_path = random.choice(image_path_list)
random_image_path

image_class = random_image_path.parent.stem
image_class

img = Image.open(random_image_path)

print(f"random_image_path: {random_image_path}")
print(f"Image class: {image_class}")
print(f"Image height: {img.height}")
print(f"Image width: {img.width}")
img

In [None]:
import numpy as np
import matplotlib.pyplot as plt

img_as_array = np.asarray(img)

plt.figure(figsize=(10,7))
plt.imshow(img_as_array)
plt.title(f"Image class: {image_class} | Image shape: {img_as_array.shape} -> [height, width, color_channels]")
plt.axis(False)

In [None]:
img_as_array

# Data Transformation

1. data to tensors
2. dataset `torch.utils.data.Dataset` -> DataLoader `torch.utils.data.DataLoader`

In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

### transform data with `torchvision.transforms`

In [None]:
# Create a transform for an image
data_transform = transforms.Compose([
    # Resize to 64x64
    transforms.Resize(size=(64,64)),
    # Flip the images randomly on horizontal
    transforms.RandomHorizontalFlip(p=0.5),
    # Turn image into torch tensor
    transforms.ToTensor()
])

In [None]:
data_transform(img).shape

In [None]:
def plot_transformed_images(image_paths: list, transform, n=3, seed=None):
    """
    Plot randomly selected images side by side with their transformed versions.

    For each sampled path a two-panel figure is drawn: the original image on the
    left and ``transform(image)`` on the right, with the class name as the figure title.

    Args:
        image_paths: List of image paths to sample from. Each element must expose
            ``.parent.stem`` (e.g. ``pathlib.Path``), which is used as the class name.
        transform: Callable applied to the opened PIL image. Its result is rearranged
            with ``.permute(1, 2, 0)``, so it is expected to return a [C, H, W] tensor.
        n: Number of images to plot; capped at ``len(image_paths)``.
        seed: Optional seed for Python's ``random`` module. Any integer, including 0,
            is used; ``None`` leaves the random state untouched.
    """
    # `if seed:` would silently ignore seed=0, so compare against None explicitly.
    if seed is not None:
        random.seed(seed)

    # random.sample raises ValueError when k is larger than the population.
    sample_size = min(n, len(image_paths))
    for img_path in random.sample(image_paths, k=sample_size):
        with Image.open(img_path) as f:
            fig, ax = plt.subplots(nrows=1, ncols=2)

            # Left panel: the untouched image.
            ax[0].imshow(f)
            ax[0].set_title(f"Original\nSize:{f.size}")
            ax[0].axis(False)

            # Right panel: the transformed image, moved to channels-last for matplotlib.
            transformed_image = transform(f).permute(1, 2, 0)
            ax[1].imshow(transformed_image)
            ax[1].set_title(f"Transformed\nSize: {transformed_image.shape}")
            ax[1].axis("off")

            # The parent folder name is the class label.
            fig.suptitle(f"Class: {img_path.parent.stem}", fontsize=16)

plot_transformed_images(image_paths=image_path_list,
                        transform=data_transform,
                        n=3,
                        seed=42) 

# loading image data from image folder

In [None]:
from torchvision import datasets
# Build the training dataset from the class-per-folder layout under train_dir.
train_data = datasets.ImageFolder(root=train_dir,
                                  transform=data_transform, # transform applied to each image
                                  target_transform=None) # optional transform applied to each label/target

# Build the test dataset the same way.
# NOTE: data_transform contains RandomHorizontalFlip, so test images are randomly flipped too.
test_data = datasets.ImageFolder(root=test_dir,
                                 transform=data_transform)

# Display both datasets as the cell output.
train_data,test_data

In [None]:
# class names as list
class_names = train_data.classes
class_names

In [None]:
# get class names as a dict
class_dict = train_data.class_to_idx
class_dict

In [None]:
# len of dataset
len(train_data), len(test_data)

In [None]:
# Index train_data dataset to get single image and label
img, label = train_data[0][0], train_data[0][1]

In [None]:
print(f"Image Tensor:\n {img}")
print(f"Image shape: {img.shape}")
print(f"Image datatype: {img.dtype}")
print(f"Image label: {label}")
print(f"Label datatype: {type(label)}")

In [None]:
# Rearrange order of dimensions
img_permute = img.permute(1,2,0)

# print shapes
print(f"Original shape: {img.shape} -> [color_channels, height, width]")
print(f"Image Permute: {img_permute.shape} -> [height, width, color_channel]")

#plt image
plt.figure(figsize=(10,7))
plt.imshow(img_permute)
plt.axis("off")
plt.title(class_names[label], fontsize=14)

# DataLoaders

In [None]:
from torch.utils.data import DataLoader
BATCH_SIZE = 32

train_dataloader = DataLoader(dataset=train_data,
                              batch_size=BATCH_SIZE,
                              num_workers=2,
                              shuffle=True)
test_dataloader = DataLoader(dataset=test_data,
                             batch_size=BATCH_SIZE,
                             num_workers=2,
                             shuffle=False)

train_dataloader, test_dataloader

In [None]:
img, label = next(iter(train_dataloader))

print(f"Image Shape: {img.shape}")
print(f"Label shape: {label.shape}")

# Option 2: Loading image data with a custom `Dataset`

1. Load images from files
2. Get the class names from the dataset
3. Get the classes as a dictionary from the dataset

Pros:
* can create a `Dataset` out of almost anything
* Not limited to PyTorch pre-built `Dataset` functions

Cons:
* Even though you can create a `Dataset` out of almost anything, that does not mean it will work
* Using a custom `Dataset` often results in writing more code, which is prone to errors or performance issues


Custom datasets in PyTorch usually subclass `torch.utils.data.Dataset`.

In [None]:
import os
import pathlib 
import torch

from PIL import Image
from torch.utils.data import Dataset 
from torchvision import transforms
from typing import Tuple, Dict, List

In [None]:
train_data.classes, train_data.class_to_idx

# Creating a helper function to get class names

1. Get the class names using `os.scandir()`
2. Raise an error if the class names aren't found
3. Turn the class names into a list and a dict and return them

In [None]:
# Directory whose entries are treated as class names.
target_directory = train_dir
print(f"Target dir: {target_directory}")

# Collect the entry names, then sort them in place.
class_names_found = [item.name for item in os.scandir(target_directory)]
class_names_found.sort()
class_names_found

In [None]:
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Discover class names from the sub-directories of ``directory``.

    Args:
        directory: Path whose immediate sub-directories are treated as classes.

    Returns:
        A tuple ``(classes, class_to_idx)``: the sorted sub-directory names and a
        dict mapping each name to its position in that sorted list.

    Raises:
        FileNotFoundError: If ``directory`` contains no sub-directories.
    """
    # Plain files are ignored; only directories count as classes.
    class_names = [item.name for item in os.scandir(directory) if item.is_dir()]
    class_names.sort()

    if len(class_names) == 0:
        raise FileNotFoundError(f"Couldn't find any classes in {directory}...please check file structure")

    index_by_class = dict(zip(class_names, range(len(class_names))))
    return class_names, index_by_class

In [None]:
find_classes(target_directory)

# Create custom `dataset` to replicate `ImageFolder`

Creating a custom dataset:
1. Subclass `torch.utils.data.Dataset`
2. Init our subclass with a target directory
3. Create several attributes:
    * paths - paths of the images
    * transform - the transform to use
    * classes - list of the target classes
    * class_to_idx - a dict of target classes mapped to integer labels
4. Create a function to `load_images()`
5. Overwrite the `__len__()` method to return the length of the dataset
6. Overwrite the `__getitem__()` method to return a given sample when passed an index

In [None]:
# Custom dataset
from torch.utils.data import Dataset 

#1. Subclass 
class ImageFolderCustom(Dataset):
    """Dataset over a ``targ_dir/class_name/image.jpg`` folder layout.

    Attributes are set in ``__init__``:
        paths: every path matching ``*/*.jpg`` under the target directory.
        transform: optional callable applied to each loaded image.
        classes / class_to_idx: class names and their integer labels, as
            returned by ``find_classes``.
    """

    def __init__(self,
            targ_dir: str,
            transform=None) -> None:
        root = pathlib.Path(targ_dir)
        # Only .jpg files exactly one folder below the root are picked up.
        self.paths = list(root.glob("*/*.jpg"))
        self.transform = transform
        self.classes, self.class_to_idx = find_classes(targ_dir)

    def load_images(self, index: int) -> Image.Image:
        """Open and return the image stored at position ``index`` of ``self.paths``."""
        return Image.open(self.paths[index])

    def __len__(self) -> int:
        """Return the number of image paths found."""
        return len(self.paths)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
        """Return one sample as ``(image, label)``.

        The image is passed through ``self.transform`` when one is set; otherwise
        the PIL image is returned as-is.
        """
        img = self.load_images(index)
        # The parent folder name is the class: data_folder/class_name/image.jpg
        label = self.class_to_idx[self.paths[index].parent.name]

        if not self.transform:
            return img, label
        return self.transform(img), label

In [None]:
# Create a transform
train_tranforms = transforms.Compose([
    transforms.Resize(size=(64,64)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor()
])

test_tranforms = transforms.Compose([
    transforms.Resize(size=(64,64)),
    transforms.ToTensor()
])

In [None]:
# test imagefoldercustom
train_data_custom = ImageFolderCustom(targ_dir=train_dir,
                                      transform=train_tranforms)

test_data_custom = ImageFolderCustom(targ_dir=test_dir,
                                     transform=test_tranforms)

In [None]:
train_data_custom, test_data_custom

In [None]:
len(train_data), len(train_data_custom)

In [None]:
len(test_data), len(test_data_custom)

In [None]:
train_data_custom.classes, train_data_custom.class_to_idx

In [None]:
# check for equality between original and custom
print(train_data_custom.classes==train_data.classes)

In [None]:
print(test_data_custom.classes==test_data.classes)

## Create a function to display random images

1. take in a `Dataset` and other params
2. cap number of images to 10
3. set random seed
4. list of random samples from the target dataset
5. Setup a matplotlib plot
6. Loop random samples image and plot with matplotlib
7. make sure dims work with matplotlib

In [None]:
# 1. create func to take in a dataset
def display_random_images(dataset: torch.utils.data.Dataset,
                          classes: List[str] = None,
                          n: int = 10,
                          display_shape: bool = True,
                          seed: int = None):
    """Plot ``n`` randomly chosen samples from ``dataset`` in a single row.

    Args:
        dataset: Dataset whose items are ``(image, label)`` pairs. The image is
            rearranged with ``.permute(1, 2, 0)``, so it is expected to be a
            [color_channels, height, width] tensor.
        classes: Optional list of class names indexed by label. When given, each
            subplot is titled with its class name.
        n: Number of images to show. Values above 10 are reduced to 10 (and the
            shape is then left out of the titles); the value is also capped at
            ``len(dataset)``.
        display_shape: Whether to add the image shape to each title.
        seed: Optional seed for Python's ``random`` module. Any integer, including
            0, is used; ``None`` leaves the random state untouched.
    """
    # 2. Keep the row readable: never show more than 10 images.
    if n > 10:
        n = 10
        display_shape = False
        print("For display, purposes, n shouldn't be larget thatn 10, setting display shape to false")
    # random.sample raises ValueError when k is larger than the population.
    n = min(n, len(dataset))

    # 3. Seed only when a seed was supplied (`if seed:` would ignore seed=0).
    if seed is not None:
        random.seed(seed)

    # 4. Pick the sample indexes.
    random_samples_idx = random.sample(range(len(dataset)), k=n)

    # 5. Set up the figure.
    plt.figure(figsize=(16, 8))

    # 6. Plot every sample.
    for i, targ_sample in enumerate(random_samples_idx):
        # Index the dataset once, so the image is loaded and transformed only once.
        targ_image, targ_label = dataset[targ_sample]

        # 7. [color_channels, h, w] -> [h, w, color_channels] for matplotlib.
        targ_image_adjust = targ_image.permute(1, 2, 0)

        plt.subplot(1, n, i + 1)
        plt.imshow(targ_image_adjust)
        plt.axis("off")

        # Build the title from the available parts. Previously `title` was unbound
        # when `classes` was not given, and plt.title(title) raised.
        title_lines = []
        if classes:
            title_lines.append(f"Class: {classes[targ_label]}")
        if display_shape:
            title_lines.append(f"shape: {targ_image_adjust.shape}")
        if title_lines:
            plt.title("\n".join(title_lines))

In [None]:
# display random images from built in
display_random_images(train_data,
                      n=5,
                      classes=class_names,
                      seed=42)

In [None]:
# display random images from custom
display_random_images(train_data_custom,n=5,
                      classes=class_names,
                      seed=42)

# Custom loaded images into a DataLoader

In [None]:
from torch.utils.data import DataLoader

BATCH_SIZE = 32
NUM_WORKERS = os.cpu_count()
train_dataloader_custom = DataLoader(dataset=train_data_custom,
                                     batch_size=BATCH_SIZE,
                                     num_workers=NUM_WORKERS,
                                     shuffle=True)

test_dataloader_custom = DataLoader(dataset=test_data_custom,
                                     batch_size=BATCH_SIZE,
                                     num_workers=NUM_WORKERS)

train_dataloader_custom, test_dataloader_custom

# Get image and label from custom dataloader

In [None]:
# Pull one batch from the custom DataLoader. Iterating train_data_custom (the
# Dataset) would return a single unbatched sample instead.
img_custom, label_custom = next(iter(train_dataloader_custom))
img_custom.shape, label_custom.shape

# Data Augmentation

In [None]:
# trivial augmentation

from torchvision import transforms

train_transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.TrivialAugmentWide(num_magnitude_bins=31),
    transforms.ToTensor()
])

test_transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor()
])

In [None]:
image_path_list = list(image_path.glob("*/*/*.jpg"))
image_path_list[:10]

In [None]:
# Visualise the augmentation pipeline defined in this section (train_transform,
# which contains TrivialAugmentWide), not the earlier flip-only train_tranforms.
plot_transformed_images(
    image_paths=image_path_list,
    transform=train_transform,
    n=3,
    seed=None
)

# Model 0: TinyVGG

### transforms and load data

In [None]:
simple_transform = transforms.Compose([
    transforms.Resize(size=(64,64)),
    transforms.ToTensor()
])

In [None]:
# 1. load and transform dataset
from torchvision import datasets

train_data_simple = datasets.ImageFolder(root=train_dir,
                                         transform=simple_transform)

test_data_simple = datasets.ImageFolder(root=test_dir,
                                         transform=simple_transform)

In [None]:
# Turn the datasets into DataLoaders
import os
from torch.utils.data import DataLoader

BATCH_SIZE = 32
NUM_WORKERS = os.cpu_count()
print(f"Creating DataLoader's with batch size {BATCH_SIZE} and {NUM_WORKERS} workers.")

train_dataloader_simple = DataLoader(dataset=train_data_simple,
                                     batch_size=BATCH_SIZE,
                                     shuffle=True,
                                     num_workers=NUM_WORKERS)

test_dataloader_simple = DataLoader(dataset=test_data_simple,
                                     batch_size=BATCH_SIZE,
                                     shuffle=False,
                                     num_workers=NUM_WORKERS)

### TinyVGG model Class

In [None]:
class TinyVGG(nn.Module):
    """
      Model architecture copying TinyVGG from CNN Explainer.

      Two convolutional blocks (Conv2d -> ReLU -> Conv2d -> ReLU -> MaxPool2d)
      followed by a Flatten + Linear classifier.

      Args:
          input_shape: Number of input channels.
          hidden_units: Number of channels produced by every convolution.
          output_shape: Number of output features of the final linear layer.
    """
    def __init__(self, input_shape: int,
                 hidden_units: int,
                 output_shape: int) -> None:
        super().__init__()
        self.conv_block_1 = self._make_conv_block(input_shape, hidden_units)
        self.conv_block_2 = self._make_conv_block(hidden_units, hidden_units)

        # The 13x13 spatial size is hard-coded: with 3x3 unpadded convolutions and
        # 2x2 pooling, a 64x64 input shrinks 64 -> 62 -> 60 -> 30 -> 28 -> 26 -> 13.
        flattened_features = hidden_units * 13 * 13
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=flattened_features,
                      out_features=output_shape),
        )

    @staticmethod
    def _make_conv_block(in_channels: int, out_channels: int) -> nn.Sequential:
        """Build one Conv-ReLU-Conv-ReLU-MaxPool block."""
        conv_kwargs = dict(kernel_size=3, stride=1, padding=0)
        return nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, **conv_kwargs),
            nn.ReLU(),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, **conv_kwargs),
            nn.ReLU(),
            # stride is spelled out; MaxPool2d would default it to the kernel size anyway
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

    def forward(self, x):
        """Run ``x`` through both convolutional blocks and the classifier."""
        for stage in (self.conv_block_1, self.conv_block_2, self.classifier):
            x = stage(x)
        return x

In [None]:
torch.manual_seed(42)
model_0 = TinyVGG(input_shape=3, hidden_units=10, output_shape=len(class_names)).to(device)
model_0

### forward pass on single image

In [None]:
# Take one batch from the DataLoader so the tensor has a batch dimension.
# Iterating train_data_simple (the Dataset) returns one unbatched [3, 64, 64]
# sample, which does not fit the classifier's hidden_units*13*13 input features.
image_batch, label_batch = next(iter(train_dataloader_simple))
image_batch.shape, label_batch.shape

In [None]:
# model_0(image_batch.to(device))

# `torchinfo`

In [None]:
try: 
    import torchinfo
except:
    !pip install torchinfo
    import torchinfo

from torchinfo import summary
summary(model_0, input_size=[1,3,64,64])

# Training TinyVGG

* `train_step()` - takes a model and a dataloader, then trains the model
* `test_step()` - takes a model and a dataloader, then evaluates the model

In [None]:
def train_step(model: torch.nn.Module,
               dataloader: torch.utils.data.DataLoader,
               loss_fn: torch.nn.Module,
               optimizer: torch.optim.Optimizer,
               device=device):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        model: Model to optimise; switched to training mode.
        dataloader: Yields ``(X, y)`` batches.
        loss_fn: Called as ``loss_fn(predictions, targets)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        ``(train_loss, train_acc)``: loss and accuracy averaged over the batches.
    """
    model.train()

    running_loss = 0
    running_acc = 0

    for features, targets in dataloader:
        features = features.to(device)
        targets = targets.to(device)

        # Forward pass and loss
        logits = model(features)
        batch_loss = loss_fn(logits, targets)
        running_loss += batch_loss.item()

        # Backpropagation and parameter update
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # Fraction of correct predictions in this batch
        predicted = torch.softmax(logits, dim=1).argmax(dim=1)
        n_correct = (predicted == targets).sum().item()
        running_acc += n_correct / len(logits)

    # Average the per-batch values over the number of batches
    n_batches = len(dataloader)
    return running_loss / n_batches, running_acc / n_batches


In [None]:
def test_step(model: torch.nn.Module,
              dataloader: torch.utils.data.DataLoader,
              loss_fn: torch.nn.Module,
              device=device):
    """Evaluate ``model`` on every batch of ``dataloader`` without updating it.

    Args:
        model: Model to evaluate; switched to eval mode.
        dataloader: Yields ``(X, y)`` batches.
        loss_fn: Called as ``loss_fn(logits, targets)``.
        device: Device the batches are moved to.

    Returns:
        ``(test_loss, test_acc)``: loss and accuracy averaged over the batches.
    """
    model.eval()

    running_loss = 0
    running_acc = 0

    # No gradients are needed for evaluation
    with torch.inference_mode():
        for features, targets in dataloader:
            features = features.to(device)
            targets = targets.to(device)

            # Forward pass and loss
            logits = model(features)
            running_loss += loss_fn(logits, targets).item()

            # Fraction of correct predictions in this batch
            predicted = logits.argmax(dim=1)
            n_correct = (predicted == targets).sum().item()
            running_acc += n_correct / len(predicted)

    # Average the per-batch values over the number of batches
    n_batches = len(dataloader)
    return running_loss / n_batches, running_acc / n_batches


In [None]:
from tqdm.auto import tqdm

# 1. Create train function that takes in various model params & optimzer
def train(model: torch.nn.Module,
          train_dataloader: torch.utils.data.DataLoader,
          test_dataloader: torch.utils.data.DataLoader,
          optimizer: torch.optim.Optimizer,
          loss_fn: torch.nn.Module = nn.CrossEntropyLoss(),
          epochs: int = 5,
          device=device):
    """Run ``train_step`` and ``test_step`` once per epoch and record the metrics.

    Args:
        model: Model to train and evaluate.
        train_dataloader: Batches passed to ``train_step``.
        test_dataloader: Batches passed to ``test_step``.
        optimizer: Optimizer passed to ``train_step``.
        loss_fn: Loss used for both steps.
        epochs: Number of epochs to run.
        device: Device forwarded to both steps.

    Returns:
        Dict with keys ``train_loss``, ``train_acc``, ``test_loss`` and
        ``test_acc``, each holding one value per epoch.
    """
    metric_names = ("train_loss", "train_acc", "test_loss", "test_acc")
    results = {name: [] for name in metric_names}

    for epoch in tqdm(range(epochs)):
        train_loss, train_acc = train_step(model=model,
                                           dataloader=train_dataloader,
                                           loss_fn=loss_fn,
                                           optimizer=optimizer,
                                           device=device)
        test_loss, test_acc = test_step(model=model,
                                        dataloader=test_dataloader,
                                        loss_fn=loss_fn,
                                        device=device)

        # Report progress for this epoch
        print(f"Epoch: {epoch} | Train Loss: {train_loss:.3f} | Train Acc: {train_acc:.3f} | Test Loss: {test_loss:.3f} | Test Acc: {test_acc:.3f}")

        # Store this epoch's values under their metric names
        epoch_values = (train_loss, train_acc, test_loss, test_acc)
        for name, value in zip(metric_names, epoch_values):
            results[name].append(value)

    return results

# Train and Evaluate Model_0

In [None]:
torch.manual_seed(42)
torch.cuda.manual_seed(42)

NUM_EPOCHS = 5

model_0 = TinyVGG(input_shape=3,
                  hidden_units=10,
                  output_shape=len(train_data.classes)).to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model_0.parameters(),
                             lr=0.001)

from timeit import default_timer as timer
start_time = timer()

# train model 0
model_0_results = train(model=model_0,
                        train_dataloader=train_dataloader_simple,
                        test_dataloader=test_dataloader_simple,
                        optimizer=optimizer,
                        loss_fn=loss_fn,
                        epochs=NUM_EPOCHS,
                        device=device)

end_time = timer()
print(f"Total training time: {end_time-start_time:.3f} seconds")

# Plot loss curves of model_0

In [None]:
# Get model_0 results keys
model_0_results.keys()

In [None]:
def plot_loss_curves(results: Dict[str,List[float]]):
    """
      Plot loss and accuracy curves from a results dict.

      Args:
          results: Dict with the keys ``train_loss``, ``test_loss``, ``train_acc``
              and ``test_acc``, each a list with one value per epoch.
    """
    # Panel title -> (legend label, values) for each curve drawn in that panel
    panels = {
        "Loss": (("train_loss", results["train_loss"]),
                 ("test_loss", results["test_loss"])),
        "Accuracy": (("train_accuracy", results["train_acc"]),
                     ("test_accuracy", results["test_acc"])),
    }

    # One x position per recorded epoch
    epochs = range(len(results["train_loss"]))

    plt.figure(figsize=(15, 7))

    for position, (panel_title, curves) in enumerate(panels.items(), start=1):
        plt.subplot(1, 2, position)
        for curve_label, values in curves:
            plt.plot(epochs, values, label=curve_label)
        plt.title(panel_title)
        plt.xlabel("Epochs")
        plt.legend()


In [None]:
plot_loss_curves(model_0_results)

# Model 1

### TinyVGG with Data augmentation

In [None]:
from torchvision import transforms
train_transforms_trivial = transforms.Compose([
    transforms.Resize(size=(64,64)),
    transforms.TrivialAugmentWide(num_magnitude_bins=31),
    transforms.ToTensor()
])

test_transforms_simple = transforms.Compose([
    transforms.Resize(size=(64,64)),
    transforms.ToTensor()
])

In [None]:
from torchvision import datasets 

train_data_augmented = datasets.ImageFolder(root=train_dir,
                                            transform=train_transforms_trivial)

test_data_augmented = datasets.ImageFolder(root=test_dir,
                                            transform=test_transforms_simple)

In [None]:
import os
from torch.utils.data import DataLoader

BATCH_SIZE = 32
NUM_WORKERS = os.cpu_count()
print(f"Creating DataLoader's with batch size {BATCH_SIZE} and {NUM_WORKERS} workers.")

torch.manual_seed(42)

train_dataloader_augmented = DataLoader(dataset=train_data_augmented,
                                     batch_size=BATCH_SIZE,
                                     shuffle=True,
                                     num_workers=NUM_WORKERS)

test_dataloader_augmented = DataLoader(dataset=test_data_augmented,
                                     batch_size=BATCH_SIZE,
                                     shuffle=False,
                                     num_workers=NUM_WORKERS)

In [None]:
torch.manual_seed(42)

# The output layer needs one unit per class. len(train_data_augmented) is the
# number of images, so use the dataset's class list instead.
model_1 = TinyVGG(input_shape=3,
                  hidden_units=10,
                  output_shape=len(train_data_augmented.classes)).to(device)
model_1

In [None]:
torch.manual_seed(42)
torch.cuda.manual_seed(42)

NUM_EPOCHS = 5

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model_1.parameters(),
                             lr=0.001)

from timeit import default_timer as timer 
start_time = timer()

model_1_results = train(model=model_1,
                        train_dataloader=train_dataloader_augmented,
                        test_dataloader=test_dataloader_simple,
                        optimizer=optimizer,
                        loss_fn=loss_fn,
                        epochs=NUM_EPOCHS,
                        device=device)

end_time = timer()
print(f"Total training time for model_1: {end_time-start_time:.3f} seconds")

In [None]:
plot_loss_curves(model_1_results)

# Compare model results

1. hard coding
2. PyTorch + Tensorboard
3. Weights & Biases
4. MLflow

In [None]:
import pandas as pd

model_0_df = pd.DataFrame(model_0_results)
model_1_df = pd.DataFrame(model_1_results)
model_0_df

In [None]:
# plots
plt.figure(figsize=(15,10))

epochs = range(len(model_0_df))

# (results column, subplot title) for each panel of the 2x2 grid, in plotting order
_panels = [
    ("train_loss", "Train Loss"),
    ("test_loss", "Test Loss"),
    ("train_acc", "Train Acc"),
    ("test_acc", "Test Acc"),
]

# Draw both models' curves for one metric per panel
for _slot, (_column, _heading) in enumerate(_panels, start=1):
    plt.subplot(2, 2, _slot)
    plt.plot(epochs, model_0_df[_column], label="Model 0")
    plt.plot(epochs, model_1_df[_column], label="Model 1")
    plt.title(_heading)
    plt.xlabel("Epochs")
    plt.legend()