In [None]:
import torch
import torchvision
from torch import nn 
from torchvision import transforms
import matplotlib.pyplot as plt
import pandas as pd
from IPython.display import Image 
import numpy as np
from scipy.io import loadmat
from scipy.io import savemat 
import dataset_utils

In [None]:
def get_imgs_bold_id(image_dataset, df):
    """Map image paths of ``image_dataset`` to the ``processid`` of their dataframe row.

    For every row of ``df`` the expected image path is rebuilt as
    ``'image_dataset/<genus_name with spaces replaced by _>/<file name of the
    first URL in image_urls>'``. Rows whose rebuilt path is not present in
    ``image_dataset.imgs`` are skipped.

    Parameters
    ----------
    image_dataset : object exposing ``imgs``
        ``imgs`` is an iterable of sequences whose first element is the image
        path (e.g. a ``torchvision.datasets.ImageFolder``).
    df : pandas.DataFrame
        Must contain the columns ``image_urls`` ('|'-separated URLs),
        ``genus_name`` and ``processid``.

    Returns
    -------
    dict
        ``{image_path: processid}``. If several rows resolve to the same
        path, the last row wins.
    """
    # Build the lookup once: membership tests become O(1) instead of a full
    # scan over image_dataset.imgs for every dataframe row.
    known_paths = {entry[0] for entry in image_dataset.imgs}

    img2dna = dict()
    for urls, genus, processid in zip(df['image_urls'], df['genus_name'], df['processid']):
        first_url = urls.split('|')[0]
        file_name = first_url.rsplit('/', 1)[-1]
        image_name_csv = 'image_dataset/' + genus.replace(' ', '_') + '/' + file_name
        if image_name_csv in known_paths:
            img2dna[image_name_csv] = processid
    return img2dna

In [None]:
import pickle
import random

# --- Load the sample metadata and the image folder ---------------------------
df = pd.read_csv('unknown_species_new_samples.csv',index_col=0)
tform = transforms.Compose([
    transforms.Resize((64,64)),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    transforms.Normalize(0.5,0.5),
])
image_dataset = torchvision.datasets.ImageFolder("image_dataset/",transform=tform)

batch_size = 1000

# --- Replace the ImageFolder labels with the genus labels from the pickle ----
# SECURITY: pickle.load can execute arbitrary code; only load this file if it
# comes from a trusted source.
with open('genusname2genuslabel.pickle', 'rb') as handle:
    genusname2genuslabel = pickle.load(handle)

# Entries of image_dataset.imgs are (path, label) tuples, so each entry is
# replaced as a whole (tuples cannot be modified in place). The genus name is
# the first path component after the "image_dataset/" root.
# NOTE(review): assumes the pickle keys are spelled exactly like the folder
# names under image_dataset/ - confirm.
for i, (imgpath, _) in enumerate(image_dataset.imgs):
    genus_dir = imgpath.replace("image_dataset/", "").split("/", 1)[0]
    image_dataset.imgs[i] = (imgpath, genusname2genuslabel[genus_dir])

# ImageFolder keeps `targets` as a separate list built at construction time;
# refresh it so code that reads `image_dataset.targets` sees the new labels.
image_dataset.targets = [label for _, label in image_dataset.imgs]

# NOTE(review): this calls the dataset_utils version of get_imgs_bold_id, not
# the function of the same name defined earlier in this notebook - confirm
# that this is intended.
img2dna = dataset_utils.get_imgs_bold_id(image_dataset,df)

# --- DNA table: keep the raw string and add the one-hot encoded sequence -----
# .copy() makes `nucleotides` an independent frame, so the column assignments
# below neither touch `df` nor trigger SettingWithCopyWarning.
nucleotides = df[['nucleotide','species_name','genus_name','processid','image_urls']].copy()

colonna_dna = df.loc[:,"nucleotide"]
nucleotides['string_nucleotides'] = nucleotides['nucleotide']
nucleotides.loc[:,'nucleotide'] = colonna_dna.apply(dataset_utils.one_hot_encoding)
random.seed(42)


In [None]:
# Translate img2dna (image path -> processid) into positional indices:
#   position of the image in image_dataset.imgs -> row position in `nucleotides`.

# Position of the first occurrence of every image path.
image_index_by_path = {}
for image_pos, (image_path, _) in enumerate(image_dataset.imgs):
    image_index_by_path.setdefault(image_path, image_pos)

# Row position of the first occurrence of every processid.
dna_index_by_processid = {}
for row_pos, processid in enumerate(nucleotides['processid'].values):
    dna_index_by_processid.setdefault(processid, row_pos)

img2dna_indices = dict()
for k, v in img2dna.items():
    if v not in dna_index_by_processid:
        raise LookupError(f"processid {v!r} (image {k!r}) not found in nucleotides['processid']")
    if k not in image_index_by_path:
        # Previously this case silently reused the index of the preceding image.
        raise LookupError(f"image path {k!r} not found in image_dataset.imgs")
    img2dna_indices[image_index_by_path[k]] = dna_index_by_processid[v]

In [None]:

# Extract the two DNA columns of `nucleotides` as NumPy arrays
# (one entry per dataframe row, not yet expanded to one entry per image).
all_not_expanded_string_dnas = nucleotides['string_nucleotides'].to_numpy()
all_not_expanded_one_hots = nucleotides['nucleotide'].to_numpy()
# Cell output: shape of the one-hot column array.
all_not_expanded_one_hots.shape

In [None]:
# Expand the per-row DNA data to one entry per image, in image_dataset.imgs order.
all_dnas, all_string_dnas, all_dna_labels = [], [], []
is_first_occurrence = []          # True for the first image that uses a given DNA row
already_seen_dna_indices = set()

for image_pos, (_, image_label) in enumerate(image_dataset.imgs):
    dna_pos = img2dna_indices[image_pos]

    all_dnas.append(torch.tensor(all_not_expanded_one_hots[dna_pos]))
    all_string_dnas.append(all_not_expanded_string_dnas[dna_pos])
    all_dna_labels.append(torch.tensor(image_label))

    is_first_occurrence.append(dna_pos not in already_seen_dna_indices)
    already_seen_dna_indices.add(dna_pos)

# Collapse the Python lists into single tensors / arrays.
all_dnas = torch.stack(all_dnas)
all_dna_labels = torch.stack(all_dna_labels)
all_string_dnas = np.array(all_string_dnas)

In [None]:
from torch.utils.data import Dataset, DataLoader

batch_size = 1000


class WholeDataset(Dataset):
    """Thin wrapper that pairs each sample of ``data`` with ``data.targets``.

    Parameters
    ----------
    data : indexable object
        Must support ``len(data)``, expose a ``targets`` sequence, and return
        from ``data[index]`` a sequence whose first element is the sample.
    transform : callable, optional
        If given, it is applied to the sample before it is returned.
        With the default ``None`` the sample is returned unchanged.
    """

    def __init__(self, data, transform=None):
        self.data = data
        # Labels are read from data.targets, not from the second element of
        # data[index].
        self.targets = data.targets
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(sample, target)`` for position ``index``."""
        x = self.data[index][0]
        if self.transform is not None:
            x = self.transform(x)
        y = self.targets[index]

        return x, y

    def __len__(self):
        """Number of samples in the wrapped dataset."""
        return len(self.data)
        
# Wrap the image folder and read it once, in order, into NumPy arrays.
whole_dataset = WholeDataset(image_dataset)
n_classes = len(np.unique(whole_dataset.targets))
whole_loader = DataLoader(
    whole_dataset,
    batch_size=batch_size,
    shuffle=False,   # keep the order of image_dataset.imgs
    num_workers=2,
)

batch_images_list = []
batch_image_labels_list = []
with torch.no_grad():
    for image_batch, label_batch in whole_loader:
        batch_images_list.append(image_batch.numpy())
        batch_image_labels_list.append(label_batch.numpy())

# Join the per-batch arrays along the first axis.
all_images = np.concatenate(batch_images_list)
all_labels = np.concatenate(batch_image_labels_list)

In [None]:
# Result of dataset_utils.image_filenames_from_df(df) (helper not shown in this
# notebook); it is exported later under the 'all_boldids' key.
# NOTE(review): the variable is named `boldids` while the helper is named
# `image_filenames_from_df`, and the export presumably expects these values to
# line up one-to-one with the image order - confirm both.
boldids = dataset_utils.image_filenames_from_df(df)

In [None]:
# Bundle everything into one dict and write it as a MATLAB .mat file.
all_dataset = {
    'all_images': all_images,
    'all_dnas': all_dnas.numpy(),
    'all_string_dnas': all_string_dnas,
    'all_labels': (all_labels+1),   # labels are stored shifted by +1
    'all_boldids': np.array(boldids),
}
savemat('matlab_dataset/insect_dataset.mat', all_dataset)