In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data import Dataset, DataLoader, random_split, TensorDataset
import torchvision.transforms as transforms
import torchvision.models
import matplotlib.pyplot as plt
import time

from PIL import ImageFile, Image
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [None]:
# Seed PyTorch and request deterministic algorithms for reproducible results
torch.manual_seed(1000)

# torch.set_deterministic is deprecated; torch.use_deterministic_algorithms is
# its replacement. Fall back to the old name only on releases that predate it.
if hasattr(torch, "use_deterministic_algorithms"):
    torch.use_deterministic_algorithms(True)
else:
    torch.set_deterministic(True)

  "torch.set_deterministic is deprecated and will be removed in a future "


In [None]:
# Build the dataloaders for the train, validation, and test sets
def get_data_loader(batch_size, path='./Mushrooms', split_seed=None):
    """Load an image-folder dataset and split it 70/20/10 into three loaders.

    Args:
        batch_size: Number of samples per batch for every returned loader.
        path: Root directory passed to ``torchvision.datasets.ImageFolder``.
            Defaults to the original hard-coded ``'./Mushrooms'``.
        split_seed: Optional integer used to seed the generator that drives
            the random split. ``None`` keeps the original behaviour of passing
            a freshly constructed ``torch.Generator()``.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    # ToTensor only converts images to float tensors scaled to [0, 1];
    # no further normalization is applied here.
    transform = transforms.Compose([
        transforms.ToTensor()
    ])

    # Load dataset
    dataset = torchvision.datasets.ImageFolder(path, transform=transform)

    # Split into train (70%), validation (20%), and test (remainder) sets.
    # The test length is computed by subtraction so the three lengths always
    # sum to the dataset size despite the integer truncation above it.
    num_images = len(dataset)
    train_len, val_len = int(0.7 * num_images), int(0.2 * num_images)
    test_len = num_images - train_len - val_len

    # The split uses its own generator, so it does not draw from the global
    # RNG seeded via torch.manual_seed.
    generator = torch.Generator()
    if split_seed is not None:
        generator.manual_seed(split_seed)

    train_data, val_data, test_data = random_split(
        dataset,
        [train_len, val_len, test_len],
        generator=generator
    )

    # Only the training loader is shuffled; validation and test batches are
    # served in a fixed order so evaluation passes are repeatable.
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader, test_loader

# Batch size 1 is required here: the export loops later in this notebook read
# data_pair[0][0] and call data_pair[1].item(), which only works when every
# batch holds exactly one image/label pair.
train_loader, val_loader, test_loader = get_data_loader(1)

# Debug helpers (uncomment to inspect a single training sample):
# print(next(iter(train_loader)))
# plt.imshow(transforms.functional.to_pil_image(next(iter(train_loader))[0].squeeze(0)))



In [None]:
# Copy images to their respective folders:
#   <split dir>/<class name>/image_<n>.jpg
ROOT_DIR = './Mushrooms'
TRAIN_DIR = './Mush_train'
VAL_DIR = './Mush_val'
TEST_DIR = './Mush_test'
# Indexed by the integer label yielded by the loaders.
# NOTE(review): presumably mirrors the class-index order of the dataset behind
# the loaders - verify against the dataset's class_to_idx mapping.
class_dirs = ['Agaricus', 'Amanita', 'Boletus', 'Cortinarius', 'Entoloma', 'Hygrocybe', 'Lactarius', 'Russula', 'Suillus']


def _export_split(loader, out_dir, class_names, start_index=0):
    """Save every image of ``loader`` under ``out_dir/<class name>/``.

    Args:
        loader: DataLoader yielding ``(images, labels)`` batches that contain
            exactly one sample each (``labels.item()`` is called per batch).
        out_dir: Directory that receives one sub-folder per class.
        class_names: Sequence mapping an integer label to its folder name.
        start_index: Number used in the first file name; it is incremented
            for each saved image so names stay unique across splits.

    Returns:
        The index following the last one used, to be passed on as
        ``start_index`` for the next split.
    """
    # Re-running the cell must neither crash nor pile re-shuffled copies on
    # top of an earlier export, so an existing split directory is left alone.
    # A directory left behind by an interrupted run has to be deleted by hand.
    if os.path.isdir(out_dir):
        print(f"'{out_dir}' already exists; skipping export (delete it to regenerate).")
        return start_index + len(loader.dataset)

    os.makedirs(out_dir)

    index = start_index
    for images, labels in loader:
        image = transforms.functional.to_pil_image(images[0])
        label = labels.item()

        # Create the class folder on first use
        class_dir = os.path.join(out_dir, class_names[label])
        os.makedirs(class_dir, exist_ok=True)

        # NOTE: the '.jpg' extension makes PIL re-encode the image as JPEG.
        image.save(os.path.join(class_dir, 'image_' + str(index) + '.jpg'))
        index += 1

    return index


# The counter runs across all three splits, so every exported file name is
# unique over the whole dataset.
image_count = 0
for split_loader, split_dir in (
    (train_loader, TRAIN_DIR),
    (val_loader, VAL_DIR),
    (test_loader, TEST_DIR),
):
    image_count = _export_split(split_loader, split_dir, class_dirs, image_count)

<a style='text-decoration:none;line-height:16px;display:flex;color:#5B5B62;padding:10px;justify-content:end;' href='https://deepnote.com?utm_source=created-in-deepnote-cell&projectId=88b4a261-3cf5-4bb1-820f-4791ebb8a30d' target="_blank">
Created in <span style='font-weight:600;margin-left:4px;'>Deepnote</span></a>