In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import numpy as np
import time
from tqdm.auto import tqdm
import os


# Fix PyTorch's global RNG seed so runs from a fresh kernel are repeatable.
torch.manual_seed(10)

# Use the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using {device} for inference')

In [None]:
### Get Dataset ###
# TODO Change if necessary. os.path.join uses the host OS's separator, so the
# path is no longer tied to Windows-style '\\' separators.
PATH = os.path.join('MLRSNet Dataset', 'MLRSNet for Semantic Scene Understanding', 'Images')
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
# Normalization is included as part of ENetB2

raw_dataset = torchvision.datasets.ImageFolder(PATH, transform=transform)

### Get Indices ###
NUM_CLASSES = 23  # TODO Change if needed
NUM_DATAPOINTS = 7000  # TODO Change if needed
# Rounded up, so the subset may hold slightly more than NUM_DATAPOINTS images
# (e.g. 23 classes * 305 = 7015).
num_each_class = int(np.ceil(NUM_DATAPOINTS / NUM_CLASSES))

# The labels produced below range over every class folder ImageFolder found,
# so that count must agree with NUM_CLASSES.
assert len(raw_dataset.classes) == NUM_CLASSES, (
    f'ImageFolder found {len(raw_dataset.classes)} class folders in {PATH!r}, '
    f'but NUM_CLASSES is {NUM_CLASSES}; update NUM_CLASSES.'
)

# Group dataset indices by the label ImageFolder actually assigned.
# raw_dataset.targets is used instead of counting files per folder with
# os.listdir for two reasons. os.listdir returns folders in arbitrary order and
# also counts non-image files. ImageFolder sorts its class folders and keeps
# only image files. Counting with os.listdir could misalign the offsets.
class_to_indices = {}
for idx, label in enumerate(raw_dataset.targets):
    class_to_indices.setdefault(label, []).append(idx)

# Take the first `num_each_class` images of every class. A class with fewer
# images contributes all of them rather than spilling into the next class
# (or past the end of the dataset).
indices = [
    idx
    for label in sorted(class_to_indices)
    for idx in class_to_indices[label][:num_each_class]
]

### Get Subset of Dataset ###
# A plain prefix such as np.arange(NUM_DATAPOINTS) is deliberately not used:
# it only covered 2-3 classes, because samples are grouped class by class.
raw_subset = torch.utils.data.Subset(raw_dataset, indices)


############## Alternative split (disabled): take the images in order, so the test set contains image types never seen during training ################
# n = len(raw_dataset)
# n_train = int(0.6*n)
# n_val = int(0.8*n)
# train_dataset = torch.utils.data.Subset(raw_dataset, range(n_train))
# val_dataset = torch.utils.data.Subset(raw_dataset, range(n_train, n_val))
# test_dataset = torch.utils.data.Subset(raw_dataset, range(n_val, n))
# print("No. Images in train, val, test", len(train_dataset), len(val_dataset), len(test_dataset))


### Get Train/Val/Test Splits ###
SPLIT_SEED = 10  # TODO Change if needed
n = len(raw_subset)
n_train = int(0.6 * n)
n_val = int(0.2 * n)
# The test split absorbs whatever int() truncation left over, so the three
# sizes always add up to n, as random_split requires.
n_test = n - n_train - n_val

# A dedicated, freshly seeded generator makes the split depend only on
# SPLIT_SEED. It no longer depends on how much of torch's global RNG earlier
# cells (or earlier runs of this cell) have consumed. Re-running this cell
# therefore reproduces the same split.
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    raw_subset, [n_train, n_val, n_test], generator=split_generator
)

### Get Data Loaders ###
LOADERS_BATCH_SIZE = 25  # AP #TODO Change if needed
LOADERS_NUM_WORKERS = 2  # AP #TODO Change if needed

# Only training benefits from a shuffled order. Aggregate evaluation metrics
# (e.g. accuracy) do not depend on batch order, and a fixed order keeps
# per-sample val/test outputs comparable across runs.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=LOADERS_BATCH_SIZE, num_workers=LOADERS_NUM_WORKERS, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=LOADERS_BATCH_SIZE, num_workers=LOADERS_NUM_WORKERS, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=LOADERS_BATCH_SIZE, num_workers=LOADERS_NUM_WORKERS, shuffle=False)