<a href="https://colab.research.google.com/github/geraldmc/torch-draft-final_project/blob/main/load_deepweeds.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### DeepWeeds Train, Eval, Test (Colab)


i) `main.ipynb`: this notebook is used to adapt, train, evaluate, and test a PyTorch ResNet-50 model with k-fold cross-validation. The notebook was developed to run on Colab and includes dependencies specific to that environment.

  

#### Data dependency

Steps required to set up the runtime environment on Colab are as follows:

 - Download the code from Github (run without change)
 - Import the project code (run without change).    
 - Download and unzip the dataset from Drive

**Note**: The third step depends on the user and environment, because it reads the zipped data from your own Google Drive. Customize the paths in `params.py` (in the project's `conf` directory) before running.
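For orientation, here is a hypothetical sketch of `conf/params.py`. The attribute names are the ones this notebook reads from `params`; every value is an assumption except `IMG_ZIP_FILE`, which matches the path printed by the copy cell below. The Drive paths (`GD_*`) in particular must be changed to wherever you stored the zips.

```python
# conf/params.py -- HYPOTHETICAL sketch; every path here is an assumption.
import os

try:
    import torch
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
except ImportError:  # lets this sketch be imported without PyTorch installed
    DEVICE = 'cpu'

PROJECT_ROOT = '/content/torch-draft-final_project-main'
DATA_PATH = os.path.join(PROJECT_ROOT, 'data')

# Sources on Google Drive (assumed locations: change these to your own).
GD_ZIP_IMG = '/content/gdrive/MyDrive/deepweeds/images.zip'
GD_ZIP_GAN = '/content/gdrive/MyDrive/deepweeds/gans.zip'
GD_WRITE_DIR = '/content/gdrive/MyDrive/deepweeds/models'

# Local copies of the zips (IMG_ZIP_FILE matches the notebook output).
IMG_ZIP_FILE = os.path.join(DATA_PATH, 'images.zip')
GAN_ZIP_FILE = os.path.join(DATA_PATH, 'gans.zip')

# Extracted images and per-fold split directories (assumed names).
IMG_DIRECTORY = os.path.join(DATA_PATH, 'images')
IMAGE_PATH = IMG_DIRECTORY                          # walked by get_file_list()
IMG_TRAIN_PATH = os.path.join(DATA_PATH, 'train')   # needs class dirs 0..8
IMG_VAL_PATH = os.path.join(DATA_PATH, 'val')       # needs class dirs 0..8
IMG_TEST_PATH = os.path.join(DATA_PATH, 'test')     # flat directory
OUTPUT_PATH = 'output'                              # relative to the project root

NUM_CLASSES = 9   # DeepWeeds: 8 weed species plus a negative class
FOLDS = 5
BATCH_SIZE = 32
```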

#### To Run
- Complete step 3 (above) for your environment by editing `conf/params.py`.
- Download the DeepWeeds data from here and store it in your Google Drive:
  [images.zip](https://drive.google.com/file/d/1xnK3B6K6KekDI55vwJ0vnc2IGoDga9cj) (468 MB)
- Execute `Run All` in Colab.


In [None]:
import os 
import os.path
import time
from datetime import datetime
import glob 
import shutil
import pickle
import copy
import pathlib
from zipfile import ZipFile
import pandas as pd
import numpy as np
import sklearn.metrics as sm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torch.utils.data import RandomSampler, random_split
from torch.utils.data import SubsetRandomSampler, WeightedRandomSampler
from torchvision.datasets import ImageFolder

### Download the code from Github

In [1]:
if os.path.isfile("../main.zip"):
  print ('Have already downloaded the project file, continuing...')
  print()
else:
  print ('Downloading file...')
  ! wget https://github.com/geraldmc/torch-draft-final_project/archive/refs/heads/main.zip
  ! unzip -qq main.zip
  %cd torch-draft-final_project-main

Downloading file...
--2022-03-29 22:06:09--  https://github.com/geraldmc/torch-draft-final_project/archive/refs/heads/main.zip
Resolving github.com (github.com)... 52.192.72.89
Connecting to github.com (github.com)|52.192.72.89|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://codeload.github.com/geraldmc/torch-draft-final_project/zip/refs/heads/main [following]
--2022-03-29 22:06:09--  https://codeload.github.com/geraldmc/torch-draft-final_project/zip/refs/heads/main
Resolving codeload.github.com (codeload.github.com)... 52.193.111.178
Connecting to codeload.github.com (codeload.github.com)|52.193.111.178|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: unspecified [application/zip]
Saving to: ‘main.zip’

main.zip                [ <=>                ] 526.76K  2.65MB/s    in 0.2s    

2022-03-29 22:06:11 (2.65 MB/s) - ‘main.zip’ saved [539403]

/content/torch-draft-final_project-main


### Import the project.

In [None]:
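try:
  import conf.params as params
  from data import transforms as tsf
  from data.test_loader import DeepWeeds_Test
except ImportError as err:
  # The project must be unzipped and the working directory set to its root
  # (see the download cell above); later cells cannot run without these.
  print('[ERROR] Could not import the project code: {}'.format(err))
  raise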

### Download the dataset from Drive.

In [3]:
from google.colab import drive

drive.mount('/content/gdrive')
print()
print("Downloading DeepWeeds images to " + params.IMG_ZIP_FILE)
!cp '{params.GD_ZIP_IMG}' '{params.IMG_ZIP_FILE}'
print()
!ls -lart {params.IMG_ZIP_FILE}

print()
print("Downloading GAN images to " + params.GAN_ZIP_FILE)
!cp '{params.GD_ZIP_GAN}' '{params.GAN_ZIP_FILE}'
print()
!ls -lart {params.GAN_ZIP_FILE}

Mounted at /content/gdrive
Downloading DeepWeeds images to /content/torch-draft-final_project-main/data/images.zip

-rw------- 1 root root 491516047 Mar 29 22:07 /content/torch-draft-final_project-main/data/images.zip


### Unzip the data files.

In [None]:
print("[INFO] Unzipping DeepWeeds images into " +  params.IMG_DIRECTORY)

with ZipFile(params.IMG_ZIP_FILE, "r") as zip_ref:
  zip_ref.extractall(params.IMG_DIRECTORY)

img_list=os.listdir(params.IMG_DIRECTORY)
print(len(img_list))

print()
print("[INFO] Unzipping GAN image dirs into " + params.DATA_PATH)

with ZipFile(params.GAN_ZIP_FILE, "r") as zip_ref:
  zip_ref.extractall(params.DATA_PATH)

gan_dir_list=os.listdir(params.DATA_PATH+'/gans/train/0')
print(len(gan_dir_list))
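To sanity-check an archive before (or after) extracting it, its entries can be counted with the standard library alone. A minimal sketch; `count_zip_images` is a hypothetical helper and the extension list is an assumption:

```python
import zipfile

def count_zip_images(zip_path, extensions=('.jpg', '.jpeg', '.png')):
    """Count image entries in a zip archive without extracting it."""
    with zipfile.ZipFile(zip_path, 'r') as zf:
        return sum(
            1 for info in zf.infolist()
            # skip directory entries; compare extensions case-insensitively
            if not info.is_dir() and info.filename.lower().endswith(extensions)
        )
```

If the archive stores the images at its top level (as the flat `os.listdir` count above suggests), `count_zip_images(params.IMG_ZIP_FILE)` should equal `len(img_list)`.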

### Get the labels.

In [7]:
LABEL_PATH = os.path.join(params.DATA_PATH, 'labels')
label_df = pd.read_csv(os.path.join(LABEL_PATH, 'labels.csv'))
!ls {LABEL_PATH}

labels.csv	  test_subset3.csv   train_subset2.csv	val_subset1.csv
test_subset0.csv  test_subset4.csv   train_subset3.csv	val_subset2.csv
test_subset1.csv  train_subset0.csv  train_subset4.csv	val_subset3.csv
test_subset2.csv  train_subset1.csv  val_subset0.csv	val_subset4.csv


## Steps For Train and Evaluate

##### 0) Get files in order (mainly Colab-specific).

    1) Instantiate new data loaders.
    2) Init a new ResNet50 model.
    3) Get/set the parameters to be optimized/updated.
    4) Train the model. Save the best model.
    5) Delete contents of the train/val directories.
    6) REPEAT 1-5 for each fold.

### 1) Supporting functions.


### 0) Functions for Colab - getting train, test, val files in place.

In [None]:
def get_file_list():
  files = []
  for dirpath, dirnames, filenames in os.walk(params.IMAGE_PATH):
    for file in filenames:
      files.append(file)
  return files

def copy_files(df, files, filepath):
  labels = dict(zip(df.Filename, df.Label)) 
  for f in files:
    try:
      src = os.path.join(params.IMG_DIRECTORY, f)
      dst = os.path.join(filepath, str(labels[f]), f)
      shutil.copyfile(src, dst)
    except KeyError:
      pass

def delete_train_val_files(path):
  for sub_dir in sorted(os.listdir(path)):
    for file_name in os.listdir(os.path.join(path, sub_dir)):
      file = os.path.join(path, sub_dir, file_name)
      if os.path.isfile(file):
        os.remove(file)

def copy_test_files(df, filepath):
  '''Copy every test image listed in df into a single flat directory.'''
  for f in df.Filename:
    src = os.path.join(params.IMG_DIRECTORY, f)
    dst = os.path.join(filepath, f)
    shutil.copyfile(src, dst)

def delete_test_files(path):
  '''Remove every file from the flat test directory at path.'''
  for file_name in os.listdir(path):
    file = os.path.join(path, file_name)
    if os.path.isfile(file):
      os.remove(file)

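def get_dataloader_counts(dl):
  '''Return per-class sample counts for the train and val loaders.
  For a ConcatDataset only the first component is counted, since every
  component is meant to be built from the same directory.
  '''
  from collections import Counter
  try:
    train_dict = dict(Counter(dl['train'].dataset.datasets[0].targets))
    val_dict = dict(Counter(dl['val'].dataset.datasets[0].targets))
  except AttributeError:  # a plain ImageFolder has no .datasets attribute
    train_dict = dict(Counter(dl['train'].dataset.targets))
    val_dict = dict(Counter(dl['val'].dataset.targets))

  return train_dict, val_dict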

def pickler(a, filename):
  with open(filename+'.pickle', 'wb') as handle:
    pickle.dump(a, handle, protocol=pickle.HIGHEST_PROTOCOL)

def unpickler(filename):
  with open(filename, 'rb') as handle:
    b = pickle.load(handle)
  return b
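`copy_files` assumes the class sub-directories (`0` to `8`) already exist under the train/val paths, and it skips any file on disk that is not in the fold CSV. A hypothetical stdlib sketch of the same staging step, driven by the CSV rows instead (assuming the `Filename` and `Label` columns used above) and creating class folders on demand, plus the matching cleanup:

```python
import csv
import os
import shutil

def stage_fold(label_csv, src_dir, split_dir):
    """Copy each image listed in label_csv to split_dir/<Label>/<Filename>.
    Hypothetical helper; returns the number of files copied per class."""
    counts = {}
    with open(label_csv, newline='') as fh:
        for row in csv.DictReader(fh):
            class_dir = os.path.join(split_dir, str(row['Label']))
            os.makedirs(class_dir, exist_ok=True)  # ImageFolder layout
            shutil.copyfile(os.path.join(src_dir, row['Filename']),
                            os.path.join(class_dir, row['Filename']))
            counts[row['Label']] = counts.get(row['Label'], 0) + 1
    return counts

def clear_split(split_dir):
    """Delete staged files but keep the (now empty) class directories."""
    for class_name in os.listdir(split_dir):
        class_dir = os.path.join(split_dir, class_name)
        for name in os.listdir(class_dir):
            path = os.path.join(class_dir, name)
            if os.path.isfile(path):
                os.remove(path)
```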

 ### 1a) Prepare single data loader.

In [None]:
def get_single_dataloader(batch_size):
  '''Creates train and validation datasets and dataloaders.
  The default transform is applied to each. See data/transforms.py.
  '''
  train_data_single = ImageFolder(
    root=params.IMG_TRAIN_PATH, 
    transform=tsf.base_transform)
  train_loader_single = DataLoader(train_data_single, 
    batch_size=batch_size, shuffle=True, 
    num_workers=2,
    pin_memory=torch.cuda.is_available())

  val_data_single = ImageFolder(
    root=params.IMG_VAL_PATH, 
    transform=tsf.base_transform)
  val_loader_single = DataLoader(val_data_single, 
    batch_size=batch_size, shuffle=False, 
    num_workers=2,
    pin_memory=torch.cuda.is_available())

  dataloaders = {}
  dataloaders['train'] = train_loader_single
  dataloaders['val'] = val_loader_single

  return dataloaders
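`ImageFolder` takes its labels from the sub-directory names, not from the CSV: it sorts the folder names as strings and numbers them from 0. A stdlib sketch of that rule (the helper name is hypothetical):

```python
import os

def folder_class_to_idx(root):
    """Mimic ImageFolder's labelling: class folders sorted as strings,
    then numbered from 0. Plain files in root are ignored."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    return {cls_name: i for i, cls_name in enumerate(classes)}
```

With the DeepWeeds folders `0` to `8`, string order equals numeric order, so the `ImageFolder` targets match the `Label` column. With ten or more classes, `'10'` would sort before `'2'` and the mapping would silently change.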

 ### 1b) Prepare DeepWeeds augmented data loader (version 1).

In [None]:
def get_DW_dataloaders1(batch_size):
  '''Creates train and validation datasets and dataloaders.
  The default, hvflip, and jitter_hue transforms are each applied to the
  full split, tripling the number of samples in each.
  See data/transforms.py.
  '''
  # Each training dataset contains 8382 x 3 images.
  # Shuffle is True for train, False for val.

  train_loader_aug = DataLoader(
  ConcatDataset([ImageFolder(
      params.IMG_TRAIN_PATH, 
      transform=tsf.data_transforms['default']),
    ImageFolder(
      params.IMG_TRAIN_PATH,
      transform=tsf.data_transforms['hvflip']),
    ImageFolder(
      params.IMG_TRAIN_PATH,
      transform=tsf.data_transforms['jitter_hue'])]), 
      batch_size=batch_size, 
      shuffle=True, num_workers=2, 
      pin_memory=torch.cuda.is_available())

  # Each validation dataset contains 3251 x 3 images.

  val_loader_aug = DataLoader(
  ConcatDataset([ImageFolder(
      params.IMG_VAL_PATH, 
      transform=tsf.data_transforms['default']),
    ImageFolder(
      params.IMG_VAL_PATH,
      transform=tsf.data_transforms['hvflip']),
    ImageFolder(
      params.IMG_VAL_PATH,
      transform=tsf.data_transforms['jitter_hue'])]), 
      batch_size=batch_size, 
      shuffle=False, num_workers=2, # shuffle is False for val!
      pin_memory=torch.cuda.is_available())

  dataloaders_aug = {}
  dataloaders_aug['train'] = train_loader_aug
  dataloaders_aug['val'] = val_loader_aug

  #print("Cumulative length of train:", dataloaders_aug['train'].dataset.cumulative_sizes)
  #print("Cumulative length of val:", dataloaders_aug['val'].dataset.cumulative_sizes)

  return dataloaders_aug
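`ConcatDataset` maps a global index to one of its component datasets through cumulative sizes (the `cumulative_sizes` attribute in the commented-out prints above). Because all three components read the same directory, local index `i` refers to the same file in every view. A stdlib sketch of that mapping (hypothetical helper), using the 8382-image training split from the comments:

```python
import bisect
from itertools import accumulate

def concat_index(sizes, idx):
    """Map a global index to (dataset number, local index) the way
    torch.utils.data.ConcatDataset does, via cumulative sizes."""
    cumulative_sizes = list(accumulate(sizes))
    total = cumulative_sizes[-1]
    if idx < 0:
        idx += total  # negative indices count from the end
    if not 0 <= idx < total:
        raise IndexError('index out of range')
    dataset_idx = bisect.bisect_right(cumulative_sizes, idx)
    start = cumulative_sizes[dataset_idx - 1] if dataset_idx > 0 else 0
    return dataset_idx, idx - start

# Three views of the same training images (default, hvflip, jitter_hue).
print(concat_index([8382] * 3, 8382))  # (1, 0): first image, hvflip view
```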

 ### 1c) Prepare DeepWeeds augmented data loader (version 2).

In [None]:
def get_DW_dataloaders2(batch_size):
  '''Creates train and validation datasets and dataloaders.
  The hvflip, default, ImageNet_autoaug, and jitter_hue transforms are
  each applied to the full split, quadrupling the number of samples in each.
  See data/transforms.py.
  '''

  train_loader_aug = DataLoader(
  ConcatDataset([ImageFolder(
      params.IMG_TRAIN_PATH, 
      transform=tsf.data_transforms['hvflip']),
    ImageFolder(
      params.IMG_TRAIN_PATH,
      transform=tsf.data_transforms['default']),
    ImageFolder(
      params.IMG_TRAIN_PATH,
      transform=tsf.data_transforms['ImageNet_autoaug']),
    ImageFolder(
      params.IMG_TRAIN_PATH,
      transform=tsf.data_transforms['jitter_hue'])]), 
      batch_size=batch_size, 
      shuffle=True, num_workers=2, 
      pin_memory=torch.cuda.is_available())

  # Each validation dataset contains 3251 x 4 images.

  val_loader_aug = DataLoader(
  ConcatDataset([ImageFolder(
      params.IMG_VAL_PATH, 
      transform=tsf.data_transforms['hvflip']),
    ImageFolder(
      params.IMG_VAL_PATH,
      transform=tsf.data_transforms['default']),
    ImageFolder(
      params.IMG_VAL_PATH,
      transform=tsf.data_transforms['ImageNet_autoaug']),
    ImageFolder(
      params.IMG_VAL_PATH,
      transform=tsf.data_transforms['jitter_hue'])]), 
      batch_size=batch_size, 
      shuffle=False, num_workers=2, # shuffle is False for val!
      pin_memory=torch.cuda.is_available())

  dataloaders_aug = {}
  dataloaders_aug['train'] = train_loader_aug
  dataloaders_aug['val'] = val_loader_aug

  #print("Cumulative length of train:", dataloaders_aug['train'].dataset.cumulative_sizes)
  #print("Cumulative length of val:", dataloaders_aug['val'].dataset.cumulative_sizes)

  return dataloaders_aug
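Each augmented loader stacks several `ImageFolder` objects built by hand, so it is easy to point one of them at the wrong split directory (for example `IMG_VAL_PATH` inside the training `ConcatDataset`), which silently leaks validation images into training. A hypothetical guard that relies only on `ConcatDataset.datasets` and the `root` attribute an `ImageFolder` keeps:

```python
import os

def check_concat_roots(parts, expected_root):
    """Raise ValueError if any component dataset was built from a
    directory other than expected_root. Hypothetical helper."""
    bad = [i for i, ds in enumerate(parts)
           if os.path.abspath(ds.root) != os.path.abspath(expected_root)]
    if bad:
        raise ValueError('datasets {} do not use root {}'.format(bad, expected_root))
```

For example: `check_concat_roots(get_DW_dataloaders2(32)['train'].dataset.datasets, params.IMG_TRAIN_PATH)`.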

### 1d) Prepare combined data loaders (DeepWeeds images with GAN images).

In [None]:
def get_comb_dataloaders(transform_method='original', batch_size=32):
  '''Creates combined (DeepWeeds + GAN) train and validation datasets and
  dataloaders by splitting the combined images 80/20.
  See data/transforms.py for the transforms.
  '''
  # Supported methods and their keys in tsf.data_transforms
  transform_keys = {
    'original': 'original',
    'random': 'random_augment',
    'auto': 'ImageNet_autoaug',
    'grayscale': 'grayscale',
    'translate': 'translate',
  }
  if transform_method not in transform_keys:
    raise ValueError("Unknown transform_method: {}".format(transform_method))
  transform = tsf.data_transforms[transform_keys[transform_method]]

  GAN_TRAIN_PATH = os.path.join(params.DATA_PATH, 'gans/train')
  image_datasets = ConcatDataset([
    ImageFolder(root=GAN_TRAIN_PATH, transform=transform),
    ImageFolder(root=params.IMG_TRAIN_PATH, transform=transform)])

  # Lengths must sum to len(image_datasets); taking val as the remainder
  # guarantees that. Note that the val subset inherits the same transform.
  n_train = int(0.8 * len(image_datasets))
  n_val = len(image_datasets) - n_train
  img_sets = dict()
  img_sets['train'], img_sets['val'] = random_split(image_datasets,
                                                    [n_train, n_val])

  dataloaders = {}
  dataloaders['train'] = DataLoader(img_sets['train'], batch_size=batch_size,
                                    shuffle=True, num_workers=2)
  # Validation order does not matter, so it is not shuffled.
  dataloaders['val'] = DataLoader(img_sets['val'], batch_size=batch_size,
                                  shuffle=False, num_workers=2)

  return dataloaders
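`random_split` requires lengths that sum to the dataset length; computing the validation size as the remainder (`n - n_train`) satisfies that for any fraction. The split also changes on every run unless a seeded generator is passed, e.g. `random_split(ds, lengths, generator=torch.Generator().manual_seed(42))`. A stdlib sketch of both ideas, with hypothetical helpers and `random.Random` standing in for the torch generator:

```python
import random

def split_lengths(n, train_frac=0.8):
    """Train/val lengths that always sum to n."""
    n_train = int(train_frac * n)
    return n_train, n - n_train

def seeded_split_indices(n, train_frac=0.8, seed=0):
    """Shuffle 0..n-1 with a fixed seed and cut it into train/val indices,
    so the same seed always reproduces the same split."""
    n_train, _ = split_lengths(n, train_frac)
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    return indices[:n_train], indices[n_train:]
```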

### Init a new ResNet50 model.

Steps:
  1) If feature extracting, freeze the pretrained parameters.
  2) Only ResNet-50 is supported for now; any other model must be set up manually.

In [None]:
def set_parameter_requires_grad(model, feature_extracting):
    '''Sets required network params for feature extraction only. 
    '''
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

def initialize_model(model_name, num_classes, feature_extract):
  '''Download a pretrained ResNet-50 and replace its final fully
  connected layer to match the number of classes in our dataset.
  '''
  model_ft = None
  input_size = 0
  if model_name == "resnet50":
    # Pretrained ImageNet weights from the torchvision v0.10.0 hub entry
    model_ft = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
    set_parameter_requires_grad(model_ft, feature_extract)
    num_ftrs = model_ft.fc.in_features
    # The new layer has requires_grad=True, so it is the only part
    # trained when feature_extract is True.
    model_ft.fc = nn.Linear(num_ftrs, num_classes)
    input_size = 224
  else:
    # Raise instead of calling exit(), which would stop the notebook kernel.
    raise ValueError("Invalid model name: {}".format(model_name))

  return model_ft, input_size

def init_model():
  '''Init model
  '''
  model, input_size = initialize_model('resnet50', params.NUM_CLASSES, 
                                          feature_extract=True)
  if torch.cuda.is_available():
    model.to('cuda') #IMPORTANT!
  
  return model, input_size
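With `feature_extract=True` every pretrained weight is frozen and only the new `fc` layer trains. Its size is easy to work out by hand: ResNet-50's final layer receives 2048 features, so a 9-class head has 2048 x 9 weights plus 9 biases, a tiny fraction of the network's roughly 25.6 million parameters. A sketch of that arithmetic (the helper name is hypothetical); in the notebook, `sum(p.numel() for p in model.parameters() if p.requires_grad)` should report the same number.

```python
def linear_param_count(in_features, out_features, bias=True):
    """Weights plus biases in an nn.Linear(in_features, out_features)."""
    return in_features * out_features + (out_features if bias else 0)

# ResNet-50's fc layer receives 2048 features; DeepWeeds has 9 classes.
trainable = linear_param_count(2048, 9)
print(trainable)  # 18441
```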

### Get/set the parameters to be optimized/updated for each k-fold.

In [None]:
def get_parameters(model, features):
  '''Build the optimizer and LR scheduler. When features is True (feature
  extraction), only parameters with requires_grad=True are passed to the
  optimizer, i.e. the newly initialized final fc layer.
  '''

  params_to_update = model.parameters()

  print("[INFO] Params to learn:")
  if features:
    params_to_update = []
    for name, param in model.named_parameters():
      if param.requires_grad:
        params_to_update.append(param)
        print("\t", name)
  else:
    for name, param in model.named_parameters():
      if param.requires_grad:
        print("\t", name)

  print()

  # Adam over the selected parameters (an SGD alternative is commented out).
  # optimizer_ft = optim.SGD(params_to_update, lr=0.001, momentum=0.9)
  opt = optim.Adam(params_to_update, lr=1e-3)
  # Halve the LR when the monitored loss plateaus, down to 1e-3 / 2**5.
  sch = optim.lr_scheduler.ReduceLROnPlateau(
      opt, patience=16, factor=0.5, min_lr=0.00003125)

  return opt, sch
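The scheduler halves the learning rate each time the monitored loss stops improving, but never below `min_lr`. A stdlib sketch of that ladder, assuming PyTorch's documented rule `new_lr = max(old_lr * factor, min_lr)` (applied only when the change exceeds `eps`, 1e-8 by default), shows that `min_lr=0.00003125` is exactly five halvings of `1e-3`. With `patience=16`, each halving needs more than 16 consecutive epochs without improvement, so it only takes effect if `scheduler.step(val_loss)` is called once per epoch and training runs well beyond the two epochs per fold used in `run_train_kfold`.

```python
def plateau_lr_ladder(initial_lr, factor, min_lr, eps=1e-8):
    """Learning rates ReduceLROnPlateau can step through, ending at min_lr.
    Hypothetical helper mirroring new_lr = max(old_lr * factor, min_lr)."""
    lrs = [initial_lr]
    while True:
        new_lr = max(lrs[-1] * factor, min_lr)
        if lrs[-1] - new_lr <= eps:  # change too small: no further update
            return lrs
        lrs.append(new_lr)

ladder = plateau_lr_ladder(1e-3, 0.5, 0.00003125)
print(len(ladder) - 1)  # 5 reductions, from 1e-3 down to 3.125e-05
```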

### Train and evaluate the model in one pass.

In [None]:
'''
    .                       o8o
  .o8                       `"'
.o888oo oooo d8b  .oooo.   oooo  ooo. .oo.
  888   `888""8P `P  )88b  `888  `888P"Y88b
  888    888      .oP"888   888   888   888
  888 .  888     d8(  888   888   888   888
  "888" d888b    `Y888""8o o888o o888o o888o
'''

def train_model(model, dataloaders, criterion, optimizer, scheduler, num_epochs):
  '''Function to train and validate concurrently. 
  '''
  since = time.time()

  # lists to store per-epoch loss and accuracy values
  val_acc_history, val_loss_history = [], []
  train_acc_history, train_loss_history = [], []

  best_model_wts = copy.deepcopy(model.state_dict())
  best_acc = 0.0

  for epoch in range(num_epochs):
    print('\t[INFO] Epoch {}/{}'.format(epoch, num_epochs - 1))
    print('\t' + '-' * 16)

    # Each epoch has a training and validation phase
    for phase in ['train', 'val']:
      if phase == 'train':
        model.train()  # Set model to training mode
      else:
        model.eval()   # Set model to evaluate mode

      running_loss = 0.0
      running_corrects = 0

      # Iterate over data.
      for inputs, labels in dataloaders[phase]:
        inputs = inputs.to(params.DEVICE)
        labels = labels.to(params.DEVICE)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward
        # track history if only in train
        with torch.set_grad_enabled(phase == 'train'):
          # Get model outputs and calculate loss
          outputs = model(inputs)
          loss = criterion(outputs, labels)

          _, preds = torch.max(outputs, 1)

          # backward + optimize only if in training phase
          if phase == 'train':
            loss.backward()
            optimizer.step()

        # statistics
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)

      epoch_loss = running_loss / len(dataloaders[phase].dataset)
      epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

      print('\t{} loss: {:.4f} acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

      if phase == 'val':
        # ReduceLROnPlateau steps once per epoch on the validation loss.
        scheduler.step(epoch_loss)
        val_acc_history.append(epoch_acc)
        val_loss_history.append(epoch_loss)
        # deep copy the model if it is the best so far
        if epoch_acc > best_acc:
          best_acc = epoch_acc
          best_model_wts = copy.deepcopy(model.state_dict())
      else:
        train_acc_history.append(epoch_acc)
        train_loss_history.append(epoch_loss)
    print()
  time_elapsed = time.time() - since
  print('[INFO] Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
  print('[INFO] Best val Acc: {:.4f}'.format(best_acc))

  # load best model weights
  model.load_state_dict(best_model_wts)
  return model, val_acc_history, val_loss_history, train_acc_history, train_loss_history, best_acc
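`train_model` multiplies each batch's mean loss by `inputs.size(0)` and divides the sum by the dataset length, so the epoch loss is a per-sample average even when the last batch is short. A stdlib illustration with made-up numbers (`epoch_mean` is hypothetical):

```python
def epoch_mean(batch_values, batch_sizes):
    """Per-sample mean from per-batch means, weighted by batch size."""
    total = sum(v * n for v, n in zip(batch_values, batch_sizes))
    return total / sum(batch_sizes)

# Two full batches of 32 and a short final batch of 4 (illustrative values).
losses = [1.0, 1.0, 4.0]
sizes = [32, 32, 4]
print(epoch_mean(losses, sizes))   # about 1.18
print(sum(losses) / len(losses))   # 2.0: the unweighted mean overstates the short batch
```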

### Functions to kick off training/eval and to save resulting model. 

In [None]:
def train(model, dataloaders, optimizer, scheduler, epochs):
  '''Function to call train/eval during each fold.
  '''
  criterion = nn.CrossEntropyLoss() # multi-class (categorical) cross-entropy on raw logits
  model, va, vl, ta, tl, best_acc = train_model(model, dataloaders, criterion, 
                                                optimizer, scheduler, num_epochs=epochs)
  return model, va, vl, ta, tl, best_acc

def save_model(m, name):
  '''Save model as a state dictionary. 
  '''
  # provide a timestamp
  timestamp = datetime.fromtimestamp(time.time()).strftime('%Y%m%d-%H%M%S')
  saved_name = os.path.join(params.OUTPUT_PATH, str(timestamp) + name + '_model.pth')
  torch.save(m.state_dict(), saved_name)
  return saved_name

### Function to make accuracy and loss plots. 

In [None]:
def make_plots(va, vl, ta, tl):
  '''Plot train/val loss and accuracy and save each figure as a PNG
  in params.OUTPUT_PATH.
  '''
  timestamp = datetime.fromtimestamp(time.time()).strftime('%Y%m%d-%H%M%S')

  # Accuracy histories hold 0-d tensors; loss histories hold floats.
  val_acc_record = [a.item() for a in va]
  val_loss_record = list(vl)
  train_acc_record = [a.item() for a in ta]
  train_loss_record = list(tl)

  # (filename prefix, y-axis label, [(series, color, legend label), ...])
  figures = [
    ('ta_va_', 'Accuracy', [(train_acc_record, 'green', 'train acc'),
                            (val_acc_record, 'blue', 'val acc')]),
    ('tl_vl_', 'Loss', [(train_loss_record, 'orange', 'train loss'),
                        (val_loss_record, 'red', 'val loss')]),
    ('ta_tl_', 'Accuracy/Loss', [(train_acc_record, 'blue', 'train acc'),
                                 (train_loss_record, 'green', 'train loss')]),
    ('va_vl_', 'Accuracy/Loss', [(val_acc_record, 'red', 'val acc'),
                                 (val_loss_record, 'orange', 'val loss')]),
  ]
  for prefix, ylabel, series in figures:
    plt.figure(figsize=(6, 4))
    for values, color, label in series:
      plt.plot(values, color=color, label=label)
    plt.xlabel('Epochs')
    plt.ylabel(ylabel)
    plt.legend()
    plt.savefig(os.path.join(params.OUTPUT_PATH, prefix + timestamp + '.png'))
    #plt.show()

### Function to run k-folds. The default is 5. 

In [None]:
def run_train_kfold(loader, batch):
  '''Run kfolds on training. Fold images are loaded from a csv file. 
  '''
  files = get_file_list()
  best_epoch_accs = []

  # K-fold cross validation, saving outputs for each fold.
  for idx in range(params.FOLDS):
      
    timestamp = datetime.fromtimestamp(time.time()).strftime('%Y%m%d-%H%M%S')
    print()
    print('[INFO] Fold {}/{} - {}'.format(idx + 1, params.FOLDS, timestamp))
    output_directory = "{}/{}/".format(params.OUTPUT_PATH, timestamp)
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)

    train_label_file = "{}/train_subset{}.csv".format(LABEL_PATH, idx)
    val_label_file = "{}/val_subset{}.csv".format(LABEL_PATH, idx)

    train_df = pd.read_csv(train_label_file)
    val_df = pd.read_csv(val_label_file)

    copy_files(train_df, files, params.IMG_TRAIN_PATH)
    copy_files(val_df, files, params.IMG_VAL_PATH)

    if loader == '_no_aug':
      deepweeds = get_single_dataloader(batch)
    elif loader == '_aug1':
      deepweeds = get_DW_dataloaders1(batch)
    elif loader == '_aug2':
      deepweeds = get_DW_dataloaders2(batch)
    else:
      raise ValueError("Unknown loader: {}".format(loader))

    # Stats for the datasets.
    train_dict, val_dict = get_dataloader_counts(deepweeds)

    print()
    print('[{}/{}] Train Class Distribution: {}'.format(idx + 1, 
                                            params.FOLDS, train_dict))

    print('[{}/{}] Val Class Distribution: {}'.format(idx + 1, 
                                            params.FOLDS, val_dict))
    print()
    torch_resnet50, input_size = init_model()
    optimizer, scheduler = get_parameters(torch_resnet50, features=True)

    # Train for 2 epochs; train() returns the model loaded with its best weights.
    best_model, va, vl, ta, tl, best_acc = train(torch_resnet50, deepweeds,
                                                  optimizer, scheduler, 2)
    saved_name = save_model(best_model, loader)
    best_epoch_accs.append(best_acc) 
    make_plots(va, vl, ta, tl)

    # Assure that files are reset -----------------------------------
    assert len(os.listdir(params.IMG_TRAIN_PATH + '/0')) != 0
    delete_train_val_files(params.IMG_TRAIN_PATH)
    assert len(os.listdir(params.IMG_TRAIN_PATH + '/0')) == 0
    delete_train_val_files(params.IMG_VAL_PATH)
    assert len(os.listdir(params.IMG_VAL_PATH + '/0')) == 0

  return best_epoch_accs, saved_name
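Since each fold is staged from `train_subset{idx}.csv` and `val_subset{idx}.csv` (and tested on `test_subset{idx}.csv`), a quick stdlib check that the three subsets of a fold are disjoint can rule out data leakage before training. A hypothetical sketch, assuming each CSV has a `Filename` column:

```python
import csv
import os

def read_filenames(csv_path):
    """Set of values in the Filename column of a label CSV."""
    with open(csv_path, newline='') as fh:
        return {row['Filename'] for row in csv.DictReader(fh)}

def fold_overlaps(label_dir, fold):
    """Filenames shared by any two of train/val/test for one fold.
    Hypothetical helper; an empty dict means the subsets are disjoint."""
    splits = {
        name: read_filenames(os.path.join(label_dir, '{}_subset{}.csv'.format(name, fold)))
        for name in ('train', 'val', 'test')
    }
    overlaps = {}
    for a, b in (('train', 'val'), ('train', 'test'), ('val', 'test')):
        shared = splits[a] & splits[b]
        if shared:
            overlaps[(a, b)] = shared
    return overlaps
```

In the notebook this could run as `for k in range(params.FOLDS): assert not fold_overlaps(LABEL_PATH, k)`.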

### Run K-folds. 

In [None]:
# ------------ RUN TRAIN/EVAL ------------------------------------
batch = 32
best_fold_accs, saved_name = run_train_kfold('_no_aug', batch)
# ----------------------------------------------------------------

### Save the current model on Google Drive

In [None]:
def copy_pth_to_gdrive(path_to_model):
  '''Optional: copy the saved model file (saved_name from the k-fold run
  above) to params.GD_WRITE_DIR on Google Drive.
  '''
  shutil.copy(path_to_model, params.GD_WRITE_DIR)
  
copy_pth_to_gdrive(saved_name)

## Steps For Test
##### 0) Get files in order (mainly Colab-specific).

    1) Copy files to single directory (not using ImageFolder).
    2) Instantiate data loaders.
    3) Load the trained ResNet50 model.
    4) Test the model. Save results.
    5) Delete contents of the test directory.
    6) REPEAT 1-5 for each fold.

### Functions for testing.

In [None]:
def return_last_pth():
  '''Return the path of the most recently created *_model.pth file.
  '''
  filename = max(pathlib.Path(params.OUTPUT_PATH).glob('*_model.pth'),
                 key=os.path.getctime)
  return filename

def print_states(m):
  # Print the model's state_dict
  print("Model's state_dict:")
  for param_tensor in m.state_dict():
    print(param_tensor, "\t", m.state_dict()[param_tensor].size())

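def load_model(name):
  '''Load a trained model for testing.
  '''
  model, input_size = init_model()
  # map_location lets a GPU-trained checkpoint load on a CPU-only runtime.
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
  model.load_state_dict(torch.load(name, map_location=device))
  return model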

### Function to test the best model.

In [None]:
import numpy as np
import pandas as pd
import sklearn.metrics as sm

'''
    .                          .
  .o8                        .o8
888oo  .ooooo.   .oooo.o .o888oo
  888   d88' `88b d88(  "8   888
  888   888ooo888 `"Y88b.    888
  888 . 888    .o o.  )88b   888 .
  "888" `Y8bod8P' 8""888P'   "888"
'''

def test(test_loader, model):
  model.eval()
  correct = 0
  targets, preds = [], []

  with torch.no_grad():
    for data, target in test_loader:
      if torch.cuda.is_available():
        data = data.cuda()
        target = target.cuda()
      output = model(data)
      pred = output.argmax(dim=1)  # index of the largest logit = predicted class
      correct += pred.eq(target).sum().item()

      targets += target.cpu().tolist()
      preds += pred.cpu().tolist()

  test_acc = 100. * correct / len(test_loader.dataset)
  # Fixing the label set keeps the matrix NUM_CLASSES x NUM_CLASSES even if
  # a class never occurs in a test subset.
  confusion_mtx = sm.confusion_matrix(
    targets, preds, labels=list(range(params.NUM_CLASSES)))
  return test_acc, confusion_mtx
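The accuracy returned by `test` should equal the trace of the confusion matrix divided by its total. A dependency-free sketch of both (hypothetical helper names), using the same orientation as `sklearn.metrics.confusion_matrix` (rows are true labels, columns are predictions) and a fixed number of classes so the matrix shape does not depend on which labels occur in a subset:

```python
def confusion_counts(targets, preds, num_classes):
    """Square count matrix: mtx[true][predicted]."""
    mtx = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(targets, preds):
        mtx[t][p] += 1
    return mtx

def accuracy_from_confusion(mtx):
    """Percentage of samples on the diagonal (correct predictions)."""
    correct = sum(mtx[i][i] for i in range(len(mtx)))
    total = sum(sum(row) for row in mtx)
    return 100.0 * correct / total
```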

### Run k-folds in testing also.

In [None]:
def run_test_kfold(model):
  '''Tests the model on each test subset; the number of folds is params.FOLDS (5 by default).
  '''
  
  metrics = {}

  for idx in range(params.FOLDS):
    test_label_file = "{}/test_subset{}.csv".format(LABEL_PATH, idx)
    test_df = pd.read_csv(test_label_file)
    copy_test_files(test_df, params.IMG_TEST_PATH)

    test_dataset = DeepWeeds_Test(test_label_file)
    test_loader  = DataLoader(test_dataset, 
      batch_size=params.BATCH_SIZE, shuffle=False,
      pin_memory=torch.cuda.is_available(), 
      num_workers=2)

    # --- Get metrics for each fold.
    metrics[idx] = test(test_loader, model)
    # ---

    delete_test_files(params.IMG_TEST_PATH)
    cnt = len([name for name in os.listdir(params.IMG_TEST_PATH) \
              if os.path.isfile(os.path.join(params.IMG_TEST_PATH, name))])
    assert cnt == 0
  
  return metrics
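`run_test_kfold` returns `{fold: (test_acc, confusion_mtx)}`. A minimal sketch (hypothetical helper) for reporting the mean and sample standard deviation of the accuracies. Note that the same model, the one loaded from `saved_name` (trained on the last fold), is scored on every test subset; under a standard k-fold design the other folds' test images may belong to that model's training subset, which would inflate those scores.

```python
import statistics

def summarize_folds(metrics):
    """Mean and sample std of test accuracy over the folds in metrics,
    where metrics maps fold index -> (test_acc, confusion_mtx)."""
    accs = [metrics[k][0] for k in sorted(metrics)]
    return statistics.mean(accs), statistics.stdev(accs)  # needs >= 2 folds
```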

### Test the best model.

In [None]:
model = load_model(saved_name)
results = run_test_kfold(model)

### Save the results of testing and plot a confusion matrix

In [None]:
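# '<timestamp>_<loader>_model.pth' -> '<timestamp>', regardless of the directory
stamp = os.path.basename(saved_name).split('_')[0]
pickle_name = os.path.join(params.OUTPUT_PATH, 'results_' + stamp)
pickler(results, pickle_name)

# results[k] is (test_acc, confusion_mtx) for test subset k; subset 0 is
# plotted here (change the index to view another subset).
cmd = sm.ConfusionMatrixDisplay(results[0][1], display_labels=[str(i) for i in range(params.NUM_CLASSES)])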
cmd.plot()

In [None]:
# Calculate other metrics. Note: From here down was included AFTER handing in the assignment! 
file = '/content/torch-draft-final_project-main/output/results_20220426-184852.pickle'
result = unpickler(file)
mtx = result[4][1]

FP = mtx.sum(axis=0) - np.diag(mtx) 
FN = mtx.sum(axis=1) - np.diag(mtx)
TP = np.diag(mtx)
TN = mtx.sum() - (FP + FN + TP)
FP = FP.astype(float)
FN = FN.astype(float)
TP = TP.astype(float)
TN = TN.astype(float)

# Sensitivity, hit rate, recall, or true positive rate
TPR = TP/(TP+FN)
# Specificity or true negative rate
TNR = TN/(TN+FP) 
# Precision or positive predictive value
PPV = TP/(TP+FP)
# Negative predictive value
NPV = TN/(TN+FN)
# Fall out or false positive rate
FPR = FP/(FP+TN)
# False negative rate
FNR = FN/(TP+FN)
# False discovery rate
FDR = FP/(TP+FP)
# Overall accuracy for each class
ACC = (TP+TN)/(TP+FP+FN+TN)
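The cell above treats each class one-vs-rest: TP is the diagonal, FP the column sum minus the diagonal, FN the row sum minus the diagonal, and TN everything else. A pure-Python sketch of the same bookkeeping (hypothetical helper) that returns `None` instead of dividing by zero; in NumPy, a class that is never predicted gives `0/0` for `PPV`, i.e. `nan` with a RuntimeWarning. Because TN dominates with nine classes, the per-class `ACC` tends to look high even for a poorly recognized class, so `TPR` (recall) and `PPV` (precision) are usually more informative.

```python
def per_class_metrics(mtx):
    """One-vs-rest counts, recall, and precision per class from a square
    confusion matrix (rows = true, columns = predicted)."""
    n = len(mtx)
    total = sum(sum(row) for row in mtx)
    out = []
    for k in range(n):
        tp = mtx[k][k]
        fn = sum(mtx[k]) - tp                       # row k minus the diagonal
        fp = sum(mtx[i][k] for i in range(n)) - tp  # column k minus the diagonal
        tn = total - tp - fp - fn
        recall = tp / (tp + fn) if tp + fn else None
        precision = tp / (tp + fp) if tp + fp else None
        out.append({'TP': tp, 'FP': fp, 'FN': fn, 'TN': tn,
                    'recall': recall, 'precision': precision})
    return out
```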

#df = pd.DataFrame({'count': {'0':0.8,'1':0.85377358,'2':0.98543689,'3':0.89705882,
#                             '4':0.94811321,'5':0.9427363,'6':0.91588785,'7':0.89655172,
#                             '8':0.94618342 }}).reset_index()
#plt.bar(range(len(df)), df["count"], color=plt.cm.Paired(np.arange(len(df))))

#None

In [None]:
def make_bar_plot(metric):
  '''Bar plot of a per-class metric (e.g. TPR, i.e. sensitivity/recall).
  Returns the rounded values keyed by class label.
  '''
  keys = [str(i) for i in range(len(metric))]
  values = [round(x, 2) for x in metric.tolist()]
  test_dict = dict(zip(keys, values))

  df = pd.DataFrame({'count': test_dict}).reset_index()
  plt.bar(range(len(df)), df["count"], color=plt.cm.Paired(np.arange(len(df))))
  return test_dict

test_dict = make_bar_plot(TPR)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme()
sns.histplot(label_df['Label'], discrete=True, color="lightcoral")  # one bar per class label
None

##### fin