In [1]:
#MatPlotLib
from matplotlib import pyplot as plt
from matplotlib.ticker import ScalarFormatter
%matplotlib inline
plt.rcParams["figure.figsize"] = [9.6, 7.2]

#Numpy
import numpy as np

#Pandas
import pandas as pd

#Pickle
import pickle

#Pytorch
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.utils.model_zoo as model_zoo

#Scipy
from scipy.stats import pearsonr, spearmanr, kendalltau, boxcox

#Sklearn
from sklearn.metrics import confusion_matrix, accuracy_score, balanced_accuracy_score, f1_score, matthews_corrcoef, cohen_kappa_score
from sklearn.model_selection import train_test_split
from sklearn.utils import class_weight
import warnings
#warnings.filterwarnings('ignore') 

#Torchvision for CV
import torchvision

#Others
import os
from PIL import Image
from google.colab import drive
# Mount Google Drive (Colab only): the dataset metadata CSV and images used
# below are read from under /content/drive/MyDrive.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [23]:
import random

class ObservationsDataset(torch.utils.data.Dataset):
  """Dataset of fixed-size windows of images grouped by GBIF observation.

  Rows of the metadata CSV are grouped by ``gbif_occurrence_id``; rows with no
  occurrence id become their own single-image observation. Each observation's
  image paths are shuffled and chunked into windows of at most ``window_size``
  paths. ``__getitem__`` loads a window, zero-pads it up to ``window_size``
  images and concatenates the image tensors along dim 1.

  Args:
    transform: callable applied to each PIL RGB image; presumably returns a
      C x H x W tensor (e.g. a torchvision pipeline ending in ToTensor) —
      TODO confirm for custom transforms.
    window_size: maximum number of images per window (default 4).
    base_path: directory that the CSV's ``image_path`` entries are joined to.
    file_name: path of the ';'-separated metadata CSV providing the columns
      ``classid``, ``image_path`` and ``gbif_occurrence_id``.

  Attributes:
    class_dict: original ``classid`` -> contiguous label index.
    inv_class_dict: label index -> original ``classid``.
    observations: list of windows, each a list of relative image paths.
    targets: label index for each entry of ``observations``.
  """

  def __init__(self, transform, window_size=4,
               base_path="/content/drive/MyDrive/Datasets/2022/trusted/images/",
               file_name="/content/drive/MyDrive/Datasets/2022/trusted/PlantCLEF2022_trusted_training_metadata.csv"):
    self.basePath = base_path
    self.fileName = file_name
    self.transform = transform
    self.window_size = window_size

    df = pd.read_csv(self.fileName, sep=';', usecols=["classid", "image_path", "gbif_occurrence_id"])
    df["Index"] = df.index
    # Rows without an occurrence id become singleton observations keyed by
    # their row index. Assign the result back: a chained
    # ``df.col.fillna(..., inplace=True)`` leaves ``df`` unchanged under
    # pandas copy-on-write, and those rows would then be silently dropped by
    # groupby's default dropna=True.
    # NOTE(review): a row index could collide with a genuine occurrence id
    # and merge two observations — confirm real ids are always large.
    df["gbif_occurrence_id"] = df["gbif_occurrence_id"].fillna(df["Index"])

    all_classes = list(df.classid.unique())
    self.class_dict = {k: v for v, k in enumerate(all_classes)}
    self.inv_class_dict = {v: k for v, k in enumerate(all_classes)}

    self.obs = df.groupby('gbif_occurrence_id')
    self.restart_windows()

  def restart_windows(self):
    """Rebuild ``observations`` and ``targets`` from scratch.

    Image order within each observation is reshuffled, so calling this again
    produces different window compositions.
    """
    self.observations = []
    self.targets = []
    # Iterate the groups directly instead of using groupby.apply purely for
    # side effects: apply is meant to build a result, older pandas versions
    # may evaluate the first group twice (duplicating its windows here), and
    # recent pandas warns about applying over the grouping column.
    for _, images in self.obs:
      self.ob(images)

  def ob(self, images):
    """Append the shuffled windows of a single observation.

    Args:
      images: DataFrame rows belonging to one observation. The class is read
        from the first row (assumes all rows share one ``classid``).
    """
    classid = images.iloc[0]["classid"]
    image_paths = images["image_path"].tolist()
    random.shuffle(image_paths)
    windows = [image_paths[x:x + self.window_size] for x in range(0, len(image_paths), self.window_size)]
    self.observations += windows
    self.targets += [self.class_dict[classid]] * len(windows)

  def __len__(self):
    """Number of windows currently built."""
    return len(self.observations)

  def __getitem__(self, index):
    """Return ``(window, target)`` for the window at ``index``.

    Empty slots (windows shorter than ``window_size``) are filled with zero
    tensors shaped like the first loaded image; all tensors are then
    concatenated along dim 1 (the height axis for C x H x W tensors).
    """
    paths = self.observations[index][:self.window_size]
    window = []
    for rel_path in paths:
      # The context manager closes the underlying file deterministically.
      with Image.open(os.path.join(self.basePath, rel_path)) as img:
        image = img.convert('RGB')
      window.append(self.transform(image))
    # Every window holds at least one path, so window[0] exists. Matching its
    # shape/dtype avoids hard-coding 3x224x224, which only fits the default
    # transform and otherwise makes torch.cat fail.
    window.extend(torch.zeros_like(window[0]) for _ in range(self.window_size - len(window)))
    window = torch.cat(window, dim=1)
    return window, self.targets[index]

In [24]:
# Per-image preprocessing for every slot of a window: match the shorter side
# to 256 px, take the central 224x224 crop, convert to a float tensor, then
# normalise with a single mean/std value applied to every channel.
# NOTE(review): provenance of the 0.3734 / 0.2041 statistics is not recorded.
transform = torchvision.transforms.Compose(transforms=[
    torchvision.transforms.Resize(size=256),
    torchvision.transforms.CenterCrop(size=224),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.3734], std=[0.2041]),
])

# Loads the metadata CSV and builds the first set of shuffled windows.
d2 = ObservationsDataset(transform)

In [25]:
# Rebuild the windows with a fresh random shuffle of each observation's images.
# The constructor already called restart_windows() once, so this only
# reshuffles (walking every observation group again).
d2.restart_windows()