Instructions:
- Perform the data exploration step (i.e., evaluate your data: number of observations, data distributions, scales, missing data, column descriptions). Note: For image data, you can still describe your data by the number of classes, the number of images, example images per class, and image sizes. Are the sizes uniform? Do the images need to be cropped or normalized?
- Plot your data. For tabular data, you will need to run scatter plots; for image data, you will need to plot your example classes.
- How will you preprocess your data? You should explain this in your README.md file and link your Jupyter notebook to it. All code and Jupyter notebooks have to be uploaded to your repo.

In [1]:
!pip install pandas
!pip install opendatasets
!pip install Pillow



In [2]:

# import opendatasets as od
# import pandas as pd

# od.download(
#     "https://www.kaggle.com/datasets/kacpergregorowicz/house-plant-species")

### Import Necessary Libraries

In [2]:
import os
from PIL import Image, ImageOps
import numpy as np
import pandas as pd
import random
from torchvision import transforms
from tqdm import tqdm
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt

# Data Exploration

### Column Descriptions

In [3]:
# Set paths
data_dir = 'house-plant-species/house_plant_species/'  # Path where images are stored in subfolders by category
output_dir = 'processed_images/'    # Directory to save processed images by category
os.makedirs(output_dir, exist_ok=True)

# Create dataframe for plant categories

# Initialize a list to hold file paths and labels
file_paths = []
labels = []

# Loop through each category folder
for category in os.listdir(data_dir):
    category_path = os.path.join(data_dir, category)

    # Check if it's a directory
    if os.path.isdir(category_path):
        for img_file in os.listdir(category_path):
            file_paths.append(os.path.join(category_path, img_file))  # Full path to the image
            labels.append(category)  # Folder name is the label

# Create the DataFrame
labels_df = pd.DataFrame({
    'file_path': file_paths,
    'species': labels
})

print(labels_df.head())

                                           file_path  \
0  house-plant-species/house_plant_species/Jade p...   
1  house-plant-species/house_plant_species/Jade p...   
2  house-plant-species/house_plant_species/Jade p...   
3  house-plant-species/house_plant_species/Jade p...   
4  house-plant-species/house_plant_species/Jade p...   

                       species  
0  Jade plant (Crassula ovata)  
1  Jade plant (Crassula ovata)  
2  Jade plant (Crassula ovata)  
3  Jade plant (Crassula ovata)  
4  Jade plant (Crassula ovata)  


The file_path column holds the full path to each image in the dataset. The species column holds the species of the plant in each image. The species names are derived from the name of the folder containing each image and serve as the labels.
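
The folder-to-label convention described above can be illustrated on a throwaway directory tree. This is a minimal sketch using only the standard library; the helper name `collect_labeled_files` and the two example folders are assumptions made for illustration and are not part of the notebook.

```python
import tempfile
from pathlib import Path


def collect_labeled_files(root):
    """Return (file_path, species) pairs, using each parent folder name as the label."""
    records = []
    for path in sorted(Path(root).glob("*/*")):
        if path.is_file():
            records.append((str(path), path.parent.name))
    return records


# Build a tiny, hypothetical dataset layout: one subfolder per species.
with tempfile.TemporaryDirectory() as tmp:
    for species, names in {"Tulip": ["a.jpg", "b.jpg"], "Yucca": ["c.jpg"]}.items():
        folder = Path(tmp) / species
        folder.mkdir()
        for name in names:
            (folder / name).touch()

    records = collect_labeled_files(tmp)
    labels = [species for _, species in records]

print(labels)
```

Because the label comes only from the folder name, a misfiled image silently receives the wrong species.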

### Attributes of Dataset

In [4]:
# Count the number of images per species
species_counts = labels_df['species'].value_counts()
print(species_counts.to_string())

# Total number of images
total_images = labels_df.shape[0]

# Total number of unique classes (species)
total_classes = labels_df['species'].nunique()

print(f"Total number of images: {total_images}")
print(f"Total number of classes (species): {total_classes}")

species
Monstera Deliciosa (Monstera deliciosa)       547
Dumb Cane (Dieffenbachia spp.)                541
Chinese evergreen (Aglaonema)                 514
Lilium (Hemerocallis)                         480
Anthurium (Anthurium andraeanum)              455
ZZ Plant (Zamioculcas zamiifolia)             438
Daffodils (Narcissus spp.)                    421
Lily of the valley (Convallaria majalis)      416
Prayer Plant (Maranta leuconeura)             400
Snake plant (Sanseviera)                      396
Peace lily                                    385
Chinese Money Plant (Pilea peperomioides)     382
Money Tree (Pachira aquatica)                 359
Jade plant (Crassula ovata)                   353
Ctenanthe                                     347
Tulip                                         341
Tradescantia                                  341
Polka Dot Plant (Hypoestes phyllostachya)     341
African Violet (Saintpaulia ionantha)         337
Elephant Ear (Alocasia spp.)              
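
The counts above suggest the classes are not perfectly balanced. A rough way to quantify this is the ratio between the largest and smallest class. The sketch below uses only a few counts copied from the listing (not the full table), and `imbalance_summary` is a hypothetical helper name.

```python
def imbalance_summary(counts):
    """Summarize class balance from a {class_name: image_count} mapping."""
    largest = max(counts, key=counts.get)
    smallest = min(counts, key=counts.get)
    total = sum(counts.values())
    return {
        "largest": largest,
        "smallest": smallest,
        "ratio": counts[largest] / counts[smallest],
        "largest_share": counts[largest] / total,
    }


# A few counts copied from the listing above (not the full 47-class table).
sample_counts = {
    "Monstera Deliciosa (Monstera deliciosa)": 547,
    "Dumb Cane (Dieffenbachia spp.)": 541,
    "Jade plant (Crassula ovata)": 353,
    "African Violet (Saintpaulia ionantha)": 337,
}

summary = imbalance_summary(sample_counts)
print(f"{summary['largest']} has {summary['ratio']:.2f}x the images of {summary['smallest']}")
```

On the full table, the same ratio can be read directly from the pandas Series as `species_counts.max() / species_counts.min()`.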

### Check for Corrupted Files

In [5]:
def check_images(directory):
    corrupted_files = []

    for subdir, dirs, files in os.walk(directory):
        for file in files:
            if file.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')):
                file_path = os.path.join(subdir, file)
                # Check if the image can be opened
                try:
                    with Image.open(file_path) as img:
                        img.verify()  # Verify that it is, in fact, an image
                except (IOError, SyntaxError):
                    corrupted_files.append(file_path)
                    print(f"Corrupted file: {file_path}")

    return corrupted_files

# Specify the root directory of your image dataset
base_dir = 'house-plant-species/house_plant_species'
corrupted = check_images(base_dir)

print("Corrupted Files:", corrupted)


Corrupted Files: []
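
One caveat on the check above: Pillow documents `verify()` as not decoding the image data, so a file with an intact header but truncated pixel data may still pass. A stricter variant also forces a full decode with `load()`. The sketch below is one possible way to do that; the helper name `is_decodable` is an assumption, and the broad `except` is a deliberate simplification.

```python
import os
import tempfile

from PIL import Image


def is_decodable(path):
    """Return True if the file can be opened, verified, and fully decoded."""
    try:
        with Image.open(path) as img:
            img.verify()  # Integrity check that does not decode pixel data
        # verify() leaves the image object unusable, so reopen before decoding
        with Image.open(path) as img:
            img.load()  # Forces a full decode of the pixel data
        return True
    except Exception:  # Broad on purpose: any failure counts as corrupted
        return False


with tempfile.TemporaryDirectory() as tmp:
    good_path = os.path.join(tmp, "good.jpg")
    bad_path = os.path.join(tmp, "truncated.jpg")

    # Write a small valid JPEG, then a copy with its second half cut off.
    Image.new("RGB", (64, 64), color=(30, 120, 60)).save(good_path, format="JPEG")
    with open(good_path, "rb") as src:
        data = src.read()
    with open(bad_path, "wb") as dst:
        dst.write(data[: len(data) // 2])

    good_ok = is_decodable(good_path)
    bad_ok = is_decodable(bad_path)

print(good_ok, bad_ok)
```

The full decode is slower than `verify()` alone, so it may be worth running once and caching the result.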


### Image Sizes

In [None]:
def check_image_sizes(df):
    sizes = []

    # Loop through each image path in the DataFrame
    for img_path in df['file_path']:
        with Image.open(img_path) as img:
            sizes.append(img.size)  # (width, height)

    # Convert to a set to get unique sizes
    unique_sizes = set(sizes)

    if len(unique_sizes) == 1:
        print("All images are uniform in size:", unique_sizes.pop())
    else:
        print("Images have varying sizes.")
        # Removed so GitHub preview works
        # print("Unique sizes:", unique_sizes)

# Check if images are uniform in size
check_image_sizes(labels_df)

Images have varying sizes.
Unique sizes: {(2160, 3183), (768, 515), (500, 435), (570, 759), (1076, 1305), (1200, 1600), (1687, 2362), (740, 416), (2600, 3900), (1468, 1500), (580, 580), (2820, 2820), (1284, 1599), (704, 1024), (446, 500), (750, 991), (1600, 1576), (960, 721), (1731, 1730), (1120, 1576), (767, 1080), (1600, 1167), (642, 640), (852, 895), (585, 780), (2375, 3000), (624, 1000), (1209, 1690), (1600, 1432), (1888, 2244), (2848, 4272), (900, 1200), (500, 592), (1280, 1570), (960, 842), (718, 900), (1024, 1235), (925, 884), (1600, 1288), (1355, 1600), (2024, 2024), (2048, 1149), (736, 539), (1080, 2335), (1536, 1532), (1920, 1000), (1500, 1609), (2686, 4029), (1528, 1528), (717, 1200), (640, 571), (994, 691), (750, 968), (1500, 1200), (1588, 2216), (2560, 2120), (3693, 2462), (1501, 1690), (982, 982), (818, 614), (794, 1160), (600, 800), (685, 1024), (1080, 1108), (1000, 999), (930, 1200), (6720, 4480), (1374, 2000), (418, 500), (184, 184), (1053, 1536), (744, 664), (1317, 12
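
Since the sizes vary this much, a compact summary of widths, heights, and aspect ratios is easier to read than the raw set. The sketch below runs on a handful of sizes copied from the output above; `summarize_sizes` is a hypothetical helper, and on the real data the list would come from the `sizes` collected in `check_image_sizes`.

```python
from statistics import median

# A handful of (width, height) pairs copied from the output above.
sample_sizes = [(2160, 3183), (768, 515), (500, 435), (580, 580), (6720, 4480), (184, 184)]


def summarize_sizes(sizes):
    """Report the range of widths and heights and the spread of aspect ratios."""
    widths = [w for w, _ in sizes]
    heights = [h for _, h in sizes]
    ratios = [w / h for w, h in sizes]
    return {
        "width_range": (min(widths), max(widths)),
        "height_range": (min(heights), max(heights)),
        "median_aspect_ratio": median(ratios),
        "square_fraction": sum(w == h for w, h in sizes) / len(sizes),
    }


summary = summarize_sizes(sample_sizes)
print(summary)
```

A wide spread of aspect ratios matters for the resize step, because cropping to a square discards more of the image the further it is from square.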

### Example Classes

In [7]:
def plot_example_images(df, num_classes=5, images_per_class=3):
    # Select the unique species (classes) and limit to num_classes if specified
    classes = df['species'].unique()[:num_classes]

    plt.figure(figsize=(15, 3 * num_classes))  # Adjust figure size to fit the number of rows

    # Loop through each class and plot a few example images
    for i, species in enumerate(classes):
        # Filter DataFrame for the current species and get a sample of images
        sample_images = df[df['species'] == species].sample(images_per_class)['file_path']

        for j, img_path in enumerate(sample_images):
            # Calculate subplot index
            plt_idx = i * images_per_class + j + 1
            plt.subplot(num_classes, images_per_class, plt_idx)

            # Open and display the image
            img = Image.open(img_path)
            plt.imshow(img)
            plt.axis('off')

            # Add the species label as a title on the middle image of each row
            if j == images_per_class // 2:
                plt.title(species, fontsize=10, fontweight='bold')

    plt.tight_layout()
    plt.show()
# Removed so GitHub preview works
# Plot 3 images for each of the example classes (adjust num_classes as needed)
# plot_example_images(labels_df, num_classes=47, images_per_class=3)

## EDA

# Preprocessing Images

### Remove Corrupted Images

In [9]:
def remove_corrupted_images(data_dir):
    for category in os.listdir(data_dir):
        category_path = os.path.join(data_dir, category)
        if os.path.isdir(category_path):
            for img_file in tqdm(os.listdir(category_path)):
                img_path = os.path.join(category_path, img_file)
                try:
                    with Image.open(img_path) as img:
                        img.verify()  # Check if the image can be opened
                except (IOError, SyntaxError):
                    os.remove(img_path)
                    print(f'Removed corrupted image: {img_file} in {category} folder')

remove_corrupted_images(data_dir)

  0%|          | 0/353 [00:00<?, ?it/s]

100%|██████████| 353/353 [00:00<00:00, 1806.49it/s]
100%|██████████| 291/291 [00:00<00:00, 2747.48it/s]
100%|██████████| 326/326 [00:00<00:00, 2446.44it/s]
100%|██████████| 189/189 [00:00<00:00, 1216.71it/s]
100%|██████████| 169/169 [00:00<00:00, 2945.65it/s]
100%|██████████| 266/266 [00:00<00:00, 3195.12it/s]
100%|██████████| 416/416 [00:00<00:00, 2093.35it/s]
100%|██████████| 400/400 [00:00<00:00, 4297.70it/s]
100%|██████████| 261/261 [00:00<00:00, 3695.00it/s]
100%|██████████| 252/252 [00:00<00:00, 3870.86it/s]
100%|██████████| 236/236 [00:00<00:00, 4349.20it/s]
100%|██████████| 130/130 [00:00<00:00, 3417.10it/s]
100%|██████████| 480/480 [00:00<00:00, 4096.43it/s]
100%|██████████| 243/243 [00:00<00:00, 2382.10it/s]
100%|██████████| 341/341 [00:00<00:00, 2821.17it/s]
100%|██████████| 66/66 [00:00<00:00, 4467.57it/s]
100%|██████████| 541/541 [00:00<00:00, 2606.26it/s]
100%|██████████| 421/421 [00:00<00:00, 3846.85it/s]
100%|██████████| 332/332 [00:00<00:00, 3501.52it/s]
100%|█████████

### Resize Images and Convert to a Consistent Format

In [10]:
# Check to see what image types we have in the original data
image_types = set()

# Traverse each subfolder and collect unique file extensions
for root, _, files in os.walk(data_dir):
    for file in files:
        # Get the file extension
        ext = os.path.splitext(file)[1].lower()  # Convert to lowercase for consistency
        image_types.add(ext)

print("Unique image types in the dataset:", image_types)

Unique image types in the dataset: {'', '.webp', '.gif', '.png', '.jpe', '.jpeg', '.jfif', '.jpg'}
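
The empty string in the set above means some files have no extension, and the resize step further down only accepts files by extension. One way to see what such files contain is to ask Pillow for the format it detects from the file contents, via the `format` attribute of an opened image. This is a small sketch; `detect_format` and the file names are assumptions for illustration.

```python
import os
import tempfile

from PIL import Image


def detect_format(path):
    """Return Pillow's detected format name (e.g. 'PNG'), or None if unreadable."""
    try:
        with Image.open(path) as img:
            return img.format
    except OSError:
        return None


with tempfile.TemporaryDirectory() as tmp:
    # Hypothetical file saved as PNG but named without any extension.
    no_ext_path = os.path.join(tmp, "plant_photo")
    Image.new("RGB", (8, 8)).save(no_ext_path, format="PNG")

    # A text file that is not an image at all.
    text_path = os.path.join(tmp, "notes")
    with open(text_path, "w") as f:
        f.write("not an image")

    detected = detect_format(no_ext_path)
    not_image = detect_format(text_path)

print(detected, not_image)
```

Filtering on the detected format instead of the extension would let extension-less images through while still rejecting non-image files.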


In [11]:
def resize(data_dir, output_dir, size=(224, 224), format='JPEG'):
    accepted_types = {'.jpg', '.jpe', '.gif', '.jpeg', '.jfif', '.png', '.webp'}

    for category in os.listdir(data_dir):
        category_path = os.path.join(data_dir, category)

        if os.path.isdir(category_path):
            # Create the output folder only for real category directories
            output_category_path = os.path.join(output_dir, category)
            os.makedirs(output_category_path, exist_ok=True)

            for img_file in tqdm(os.listdir(category_path)):
                stem, ext = os.path.splitext(img_file)
                ext = ext.lower()

                if ext in accepted_types:
                    img_path = os.path.join(category_path, img_file)
                    # splitext keeps dots inside the file name, unlike split('.')[0]
                    output_path = os.path.join(output_category_path, stem + '.jpg')

                    try:
                        with Image.open(img_path) as img:
                            # Convert palette images with transparency to RGBA first
                            if img.mode == 'P' or img.mode == 'RGBA':
                                img = img.convert('RGBA')

                            # Resize
                            img = ImageOps.fit(img, size, Image.LANCZOS).convert('RGB')
                            img.save(output_path, format=format)
                    except Exception as e:
                        print(f'Error processing {img_file} in {category} folder: {e}')

resize(data_dir, output_dir)

100%|██████████| 353/353 [00:14<00:00, 23.84it/s]
100%|██████████| 291/291 [00:12<00:00, 23.10it/s]
100%|██████████| 326/326 [00:14<00:00, 22.81it/s]
100%|██████████| 189/189 [00:06<00:00, 27.53it/s]
100%|██████████| 169/169 [00:04<00:00, 38.53it/s]
100%|██████████| 266/266 [00:08<00:00, 31.96it/s]
100%|██████████| 416/416 [00:12<00:00, 33.20it/s]
100%|██████████| 400/400 [00:17<00:00, 23.41it/s]
100%|██████████| 261/261 [00:11<00:00, 23.50it/s]
100%|██████████| 252/252 [00:09<00:00, 26.86it/s]
100%|██████████| 236/236 [00:11<00:00, 19.93it/s]
100%|██████████| 130/130 [00:04<00:00, 32.11it/s]
100%|██████████| 480/480 [00:17<00:00, 28.23it/s]
100%|██████████| 243/243 [00:08<00:00, 27.47it/s]
100%|██████████| 341/341 [00:10<00:00, 31.06it/s]
100%|██████████| 66/66 [00:02<00:00, 29.29it/s]
100%|██████████| 541/541 [00:20<00:00, 26.02it/s]
100%|██████████| 421/421 [00:14<00:00, 28.48it/s]
100%|██████████| 332/332 [00:11<00:00, 29.06it/s]
100%|██████████| 306/306 [00:12<00:00, 24.65it/s]
10

In [14]:
# # Zip the output directory and download locally
# import shutil
# # from google.colab import files

# # Path to the output directory where resized images are saved
# output_dir = 'output_images'  # Replace with your output directory
# zip_file = 'resized_images.zip'

# # Compress the output directory into a zip file
# shutil.make_archive('resized_images', 'zip', output_dir)

# # Download the zip file to your local machine
# files.download('resized_images.zip')

### Label Encoding the Categories

In [15]:
# Get unique categories
category_labels = labels_df['species'].unique()
print("Category Labels:", category_labels)
print("Number of Categories:", len(category_labels))

Category Labels: ['Jade plant (Crassula ovata)' 'Rubber Plant (Ficus elastica)'
 'Schefflera' 'Areca Palm (Dypsis lutescens)'
 'Asparagus Fern (Asparagus setaceus)'
 'Iron Cross begonia (Begonia masoniana)'
 'Lily of the valley (Convallaria majalis)'
 'Prayer Plant (Maranta leuconeura)' 'Dracaena' 'Aloe Vera'
 'Begonia (Begonia spp.)' 'Kalanchoe' 'Lilium (Hemerocallis)'
 'Pothos (Ivy arum)' 'Polka Dot Plant (Hypoestes phyllostachya)' 'Yucca'
 'Dumb Cane (Dieffenbachia spp.)' 'Daffodils (Narcissus spp.)'
 'Elephant Ear (Alocasia spp.)' 'Poinsettia (Euphorbia pulcherrima)'
 'Calathea' 'Monstera Deliciosa (Monstera deliciosa)'
 'Hyacinth (Hyacinthus orientalis)' 'Sago Palm (Cycas revoluta)'
 'Chrysanthemum' 'Ponytail Palm (Beaucarnea recurvata)'
 'Anthurium (Anthurium andraeanum)' 'Tradescantia'
 'Chinese Money Plant (Pilea peperomioides)'
 'Chinese evergreen (Aglaonema)' 'Tulip'
 'Parlor Palm (Chamaedorea elegans)' 'Peace lily'
 'ZZ Plant (Zamioculcas zamiifolia)' 'Venus Flytrap'
 'Chris

In [16]:
# Save labels_df to csv for later use
labels_df.to_csv('labels_df.csv', index=False)

### Label Encoding and One-Hot Encoding
We have two encoding options. Label encoding maps each species to a single integer, which is compact and matches the class-index targets that PyTorch's nn.CrossEntropyLoss (used later in this notebook) accepts. One-hot encoding is the more common choice for non-ordinal categories because it does not imply an order, but with 47 classes it adds 47 columns per image and is more expensive to store and process.
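
To make the difference concrete, here is a minimal pure-Python sketch of both encodings on a hypothetical four-image label list. It mirrors the fact that scikit-learn's `LabelEncoder` stores the sorted unique class names in `classes_`, but it is an illustration and not the code used below.

```python
species = ["Tulip", "Yucca", "Aloe Vera", "Tulip"]  # hypothetical labels

# Label encoding: sorted unique class names mapped to 0..K-1
classes = sorted(set(species))
class_to_index = {name: i for i, name in enumerate(classes)}
encoded = [class_to_index[name] for name in species]

# One-hot encoding: one column per class, a single 1 in the class's column
one_hot = [[int(i == idx) for i in range(len(classes))] for idx in encoded]

print(encoded)
print(one_hot)

# The integer label can always be recovered from the one-hot row
recovered = [row.index(1) for row in one_hot]
```

Both forms carry the same information; the one-hot form just spends K values per image instead of one.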

In [17]:
# Label Encoding
label_encoder = LabelEncoder()

# Fit the encoder and transform the species column
labels_df['encoded_label'] = label_encoder.fit_transform(labels_df['species'])

# Display the mapping of labels to integers
label_mapping = dict(zip(label_encoder.classes_, label_encoder.transform(label_encoder.classes_)))
print("Label Encoding Mapping:", label_mapping)

labels_df['encoded_label']

Label Encoding Mapping: {'African Violet (Saintpaulia ionantha)': 0, 'Aloe Vera': 1, 'Anthurium (Anthurium andraeanum)': 2, 'Areca Palm (Dypsis lutescens)': 3, 'Asparagus Fern (Asparagus setaceus)': 4, 'Begonia (Begonia spp.)': 5, 'Bird of Paradise (Strelitzia reginae)': 6, 'Birds Nest Fern (Asplenium nidus)': 7, 'Boston Fern (Nephrolepis exaltata)': 8, 'Calathea': 9, 'Cast Iron Plant (Aspidistra elatior)': 10, 'Chinese Money Plant (Pilea peperomioides)': 11, 'Chinese evergreen (Aglaonema)': 12, 'Christmas Cactus (Schlumbergera bridgesii)': 13, 'Chrysanthemum': 14, 'Ctenanthe': 15, 'Daffodils (Narcissus spp.)': 16, 'Dracaena': 17, 'Dumb Cane (Dieffenbachia spp.)': 18, 'Elephant Ear (Alocasia spp.)': 19, 'English Ivy (Hedera helix)': 20, 'Hyacinth (Hyacinthus orientalis)': 21, 'Iron Cross begonia (Begonia masoniana)': 22, 'Jade plant (Crassula ovata)': 23, 'Kalanchoe': 24, 'Lilium (Hemerocallis)': 25, 'Lily of the valley (Convallaria majalis)': 26, 'Money Tree (Pachira aquatica)': 27, '

0        23
1        23
2        23
3        23
4        23
         ..
14785     7
14786     7
14787     7
14788     7
14789     7
Name: encoded_label, Length: 14790, dtype: int64

In [18]:
# Save the label encodings
label_mapping = {key: int(value) for key, value in label_mapping.items()}

# Save the label mapping for reference
import json
with open('label_mapping.json', 'w') as f:
    json.dump(label_mapping, f)
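
The `int(value)` conversion above is needed because the standard `json` module cannot serialize NumPy integer types. Once saved, the mapping can be loaded and inverted to turn predicted class indices back into species names. The sketch below uses a three-entry excerpt of the mapping and a temporary file; `index_to_species` and `predicted_indices` are names made up for illustration.

```python
import json
import os
import tempfile

# Three entries excerpted from the mapping printed above; the real file holds all 47.
label_mapping = {"Aloe Vera": 1, "Calathea": 9, "Jade plant (Crassula ovata)": 23}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "label_mapping.json")
    with open(path, "w") as f:
        json.dump(label_mapping, f)
    with open(path) as f:
        loaded = json.load(f)

# Invert the mapping so integer predictions can be turned back into names
index_to_species = {index: name for name, index in loaded.items()}

predicted_indices = [23, 1]  # e.g. argmax outputs from a classifier
predicted_names = [index_to_species[i] for i in predicted_indices]
print(predicted_names)
```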


In [19]:
# One-Hot Encoding
one_hot_labels = pd.get_dummies(labels_df['species'], prefix='species')

# Concatenate the one-hot encoded columns back to the original DataFrame
labels_df = pd.concat([labels_df, one_hot_labels], axis=1)

# Drop the original 'species' column if it’s no longer needed
labels_df = labels_df.drop('species', axis=1)

# Display the result
labels_df.head()

Unnamed: 0,file_path,encoded_label,species_African Violet (Saintpaulia ionantha),species_Aloe Vera,species_Anthurium (Anthurium andraeanum),species_Areca Palm (Dypsis lutescens),species_Asparagus Fern (Asparagus setaceus),species_Begonia (Begonia spp.),species_Bird of Paradise (Strelitzia reginae),species_Birds Nest Fern (Asplenium nidus),...,species_Rattlesnake Plant (Calathea lancifolia),species_Rubber Plant (Ficus elastica),species_Sago Palm (Cycas revoluta),species_Schefflera,species_Snake plant (Sanseviera),species_Tradescantia,species_Tulip,species_Venus Flytrap,species_Yucca,species_ZZ Plant (Zamioculcas zamiifolia)
0,house-plant-species/house_plant_species/Jade p...,23,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
1,house-plant-species/house_plant_species/Jade p...,23,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
2,house-plant-species/house_plant_species/Jade p...,23,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
3,house-plant-species/house_plant_species/Jade p...,23,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
4,house-plant-species/house_plant_species/Jade p...,23,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False


In [20]:
# Save the one-hot encodings
labels_df.to_csv('one_hot_encoded_labels.csv', index=False)

In [None]:
import os
import shutil
from sklearn.model_selection import train_test_split

ROOT_DIR = 'processed_images'
TRAIN_DIR = 'images_train'
VAL_DIR = 'images_val'
TEST_DIR = 'images_test'
TEST_RATIO = 0.15   # 15% of all images go to the test set
VAL_RATIO = 0.1765  # Fraction of the remaining 85%, i.e. about 15% of all images
RANDOM_STATE = 42

os.makedirs(TRAIN_DIR, exist_ok=True)
os.makedirs(VAL_DIR, exist_ok=True)
os.makedirs(TEST_DIR, exist_ok=True)

for class_name in os.listdir(ROOT_DIR):
    class_path = os.path.join(ROOT_DIR, class_name)
    if os.path.isdir(class_path):
        if ".DS_Store" in class_path:
            continue

        images = sorted(os.listdir(class_path))  # Sorted so the seeded split is reproducible

        train_val_images, test_images = train_test_split(images, test_size=TEST_RATIO, random_state=RANDOM_STATE)
        train_images, val_images = train_test_split(train_val_images, test_size=VAL_RATIO, random_state=RANDOM_STATE)

        train_class_dir = os.path.join(TRAIN_DIR, class_name)
        val_class_dir = os.path.join(VAL_DIR, class_name)
        test_class_dir = os.path.join(TEST_DIR, class_name)
        os.makedirs(train_class_dir, exist_ok=True)
        os.makedirs(val_class_dir, exist_ok=True)
        os.makedirs(test_class_dir, exist_ok=True)

        # Copy train images
        for img_name in train_images:
            shutil.copy(os.path.join(class_path, img_name), os.path.join(train_class_dir, img_name))

        # Copy validation images
        for img_name in val_images:
            shutil.copy(os.path.join(class_path, img_name), os.path.join(val_class_dir, img_name))

        # Copy test images
        for img_name in test_images:
            shutil.copy(os.path.join(class_path, img_name), os.path.join(test_class_dir, img_name))
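
The value of `VAL_RATIO` looks odd until the two-stage split is written out: the validation split is taken from what remains after the test split, so 0.1765 of the remaining 85% is roughly 15% of the whole. The arithmetic can be checked with a few lines (a sketch of the intended shares; the realized per-class counts will differ slightly because each class is split into whole images).

```python
TEST_RATIO = 0.15
VAL_RATIO = 0.1765  # applied to what remains after the test split

test_share = TEST_RATIO
val_share = (1 - TEST_RATIO) * VAL_RATIO
train_share = (1 - TEST_RATIO) * (1 - VAL_RATIO)

print(f"train={train_share:.4f}, val={val_share:.4f}, test={test_share:.4f}")
```

This gives an approximate 70/15/15 split of each class.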

In [22]:
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import torch.nn as nn
import torch.nn.functional as F

dataset = datasets.ImageFolder(root=data_dir)

num_classes = len(dataset.classes)

In [23]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),                # Resize to the desired size
    transforms.ToTensor(),                        # Convert PIL Image to Tensor
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalize each RGB channel (mean=0.5, std=0.5)
])
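
To see what the normalization step does to pixel values: `ToTensor` scales 8-bit pixels to the range [0, 1], and `Normalize` then applies `(value - mean) / std` per channel. With a mean and standard deviation of 0.5, this maps [0, 1] to [-1, 1]. A plain-Python sketch of the arithmetic (the `normalize` function here is a stand-in, not the torchvision implementation):

```python
def normalize(value, mean=0.5, std=0.5):
    """Apply the per-channel formula used by Normalize: (value - mean) / std."""
    return (value - mean) / std


# ToTensor scales 8-bit pixel values from [0, 255] to [0.0, 1.0]
for pixel in (0, 128, 255):
    scaled = pixel / 255
    print(pixel, round(normalize(scaled), 4))
```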

In [24]:
train_dataset = datasets.ImageFolder(root=TRAIN_DIR, transform=transform)
val_dataset = datasets.ImageFolder(root=VAL_DIR, transform=transform)
test_dataset = datasets.ImageFolder(root=TEST_DIR, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [25]:
def train(model, train_loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    for batch_idx, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)

        # Print progress every few batches
        if batch_idx % 10 == 0:
            print(f"Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}")

    return running_loss / len(train_loader.dataset)


def evaluate(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

In [26]:
import torch.nn as nn
import torch.nn.functional as F

class PlantClassifierCNN(nn.Module):
    def __init__(self, num_classes):
        super(PlantClassifierCNN, self).__init__()
        # Define layers of the CNN
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)  # 3 input channels for RGB
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64 * 56 * 56, 128)  # Adjust based on image size
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 56 * 56)  # Flatten
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Initialize model, loss function, and optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = PlantClassifierCNN(num_classes=num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
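
The `64 * 56 * 56` input size of `fc1` follows from the layer settings and the 224x224 input. The sketch below traces the spatial size through the network using the standard output-size formula for convolution and pooling layers; `conv_output_size` is a helper written for this illustration.

```python
def conv_output_size(size, kernel_size, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer (floor division)."""
    return (size + 2 * padding - kernel_size) // stride + 1


size = 224
size = conv_output_size(size, kernel_size=3, stride=1, padding=1)  # conv1 -> 224
size = conv_output_size(size, kernel_size=2, stride=2)             # pool  -> 112
size = conv_output_size(size, kernel_size=3, stride=1, padding=1)  # conv2 -> 112
size = conv_output_size(size, kernel_size=2, stride=2)             # pool  -> 56

flattened = 64 * size * size          # channels * height * width
fc1_params = flattened * 128 + 128    # weights + biases of fc1

print(size, flattened, fc1_params)
```

Nearly all of the model's parameters sit in `fc1`, so changing the input resolution or adding another pooling stage changes the model size substantially.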

In [27]:
num_epochs = 2

for epoch in range(num_epochs):
    train_loss = train(model, train_loader, criterion, optimizer, device)
    test_accuracy = evaluate(model, test_loader, device)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}, Test Accuracy: {test_accuracy:.4f}", flush=True)

Batch 0/167, Loss: 3.8636
Batch 10/167, Loss: 3.8696
Batch 20/167, Loss: 3.8264
Batch 30/167, Loss: 3.8293
Batch 40/167, Loss: 3.8594
Batch 50/167, Loss: 3.7293
Batch 60/167, Loss: 3.8098
Batch 70/167, Loss: 3.7769
Batch 80/167, Loss: 3.7560
Batch 90/167, Loss: 3.5810
Batch 100/167, Loss: 3.6334
Batch 110/167, Loss: 3.7423
Batch 120/167, Loss: 3.6191
Batch 130/167, Loss: 3.6332
Batch 140/167, Loss: 3.5511
Batch 150/167, Loss: 3.6769
Batch 160/167, Loss: 3.5501
Epoch [1/2], Loss: 3.9020, Test Accuracy: 0.0870
Batch 0/167, Loss: 3.5658
Batch 10/167, Loss: 3.5257
Batch 20/167, Loss: 3.3762
Batch 30/167, Loss: 3.2695
Batch 40/167, Loss: 3.3924
Batch 50/167, Loss: 3.3137
Batch 60/167, Loss: 3.2719
Batch 70/167, Loss: 3.3123
Batch 80/167, Loss: 3.2257
Batch 90/167, Loss: 3.3731
Batch 100/167, Loss: 2.9765
Batch 110/167, Loss: 3.1187
Batch 120/167, Loss: 3.1899
Batch 130/167, Loss: 3.0829
Batch 140/167, Loss: 3.1663
Batch 150/167, Loss: 3.1456
Batch 160/167, Loss: 3.0145
Epoch [2/2], Loss: 3.

In [28]:
test_accuracy = evaluate(model, test_loader, device)

In [29]:
test_accuracy

0.18250780901383312

In [31]:
train_accuracy = evaluate(model, train_loader, device)
train_accuracy

0.24558602554470324

In [32]:
validation_accuracy = evaluate(model, val_loader, device)
validation_accuracy

0.24558602554470324
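
For context, these numbers can be compared against a chance baseline. With 47 classes, a model that predicts a uniform distribution has a cross-entropy loss of ln(47) and an expected accuracy of 1/47. The sketch below computes both; it assumes the 47-class count stated earlier and reuses the test accuracy reported above.

```python
import math

num_classes = 47  # number of species stated earlier in the notebook

chance_loss = math.log(num_classes)   # cross-entropy of a uniform prediction
chance_accuracy = 1 / num_classes     # expected accuracy of uniform guessing

print(f"chance loss: {chance_loss:.4f}")
print(f"chance accuracy: {chance_accuracy:.4f}")

test_accuracy = 0.18250780901383312   # value reported above
print(f"test accuracy is {test_accuracy / chance_accuracy:.1f}x chance")
```

The first-batch loss of about 3.86 in the training log is close to ln(47), which is what an untrained classifier would be expected to produce.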