In [None]:
import torch
from torchvision import datasets, models, transforms
from google.colab import drive
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, roc_auc_score
from tqdm import tqdm
import time
import datetime
from google.colab import files
import seaborn as sns
import numpy as np 
import pandas as pd 
import cv2
from torch import nn
from torch import optim
import torch.nn.functional as F
from pathlib import Path
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
import os


# Mount Google Drive into the Colab filesystem at /content/drive.
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Root folder of the competition data; it ends with a slash so that
# sub-folder names can be appended directly.
data_dir = '/content/drive/My Drive/competitions_egor/zindi_invertebrates/'
# Training and validation image folders under the root.
train_dir = f'{data_dir}train'
valid_dir = f'{data_dir}val'

In [None]:
# Image pipelines, datasets and loaders for training and validation.

# Training pipeline: random rotation, random resized crop and horizontal flip
# as augmentation, then tensor conversion and per-channel normalisation.
training_transforms = transforms.Compose([
    transforms.RandomRotation(degrees=30),
    transforms.RandomResizedCrop(size=224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Validation pipeline: deterministic resize + centre crop, same normalisation.
validation_transforms = transforms.Compose([
    transforms.Resize(size=256),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Datasets are read from the train/val folders with the matching pipeline.
training_dataset = datasets.ImageFolder(root=train_dir, transform=training_transforms)
validation_dataset = datasets.ImageFolder(root=valid_dir, transform=validation_transforms)

# Only the training loader shuffles its samples.
train_loader = torch.utils.data.DataLoader(dataset=training_dataset,
                                           batch_size=64,
                                           shuffle=True)
validate_loader = torch.utils.data.DataLoader(dataset=validation_dataset,
                                              batch_size=32)

In [None]:
# Use a pretrained ResNet-18 as a frozen feature extractor.
model = models.resnet18(pretrained=True)
# Freeze every pretrained weight; only the new head created below is trainable.
for param in model.parameters():
    param.requires_grad = False
num_ftrs = model.fc.in_features
# Replace the final layer with a fresh linear head. Its width is taken from the
# training dataset so that it always matches the label indices produced by the
# loader and the class columns built from `classes` (previously hard-coded 137).
model.fc = nn.Linear(num_ftrs, len(training_dataset.classes))

In [None]:
#validation
def validation(model, validateloader, criterion):
    """Accumulate loss and accuracy of ``model`` over ``validateloader``.

    Args:
        model: network to evaluate. Batches are moved to the device of its
            first parameter (CPU when the model has no parameters), so the
            function works on both GPU and CPU runtimes.
        validateloader: iterable yielding ``(images, labels)`` batches.
        criterion: loss callable invoked as ``criterion(output, labels)``.

    Returns:
        Tuple ``(val_loss, accuracy)``: the sum of the per-batch losses (a
        Python number) and the sum of the per-batch accuracies. Both are sums,
        not means; the caller divides by the number of batches.
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    val_loss = 0
    accuracy = 0

    # No gradients are needed for evaluation.
    with torch.no_grad():
        for images, labels in validateloader:
            images, labels = images.to(device), labels.to(device)
            output = model(images)
            val_loss += criterion(output, labels).item()
            # The predicted class is the arg-max of the raw scores; no
            # exponentiation is needed to find it.
            predicted = output.argmax(dim=1)
            accuracy += (predicted == labels).float().cpu().mean()
    return val_loss, accuracy

In [None]:
# Loss for multi-class classification.
criterion = nn.CrossEntropyLoss()
# Adam is given only the parameters of the final `fc` layer.
optimizer = optim.Adam(params=model.fc.parameters(), lr=0.01)

In [None]:
# Train the classifier
def train_classifier(epochs=5, print_every=43, device='cuda'):
    """Train the notebook-level ``model`` and report progress periodically.

    Uses the notebook-level ``model``, ``train_loader``, ``validate_loader``,
    ``optimizer`` and ``criterion``. Every ``print_every`` optimisation steps
    the model is evaluated with ``validation`` and a line with the average
    training loss of the last ``print_every`` steps, the validation loss and
    the validation accuracy is printed.

    Args:
        epochs: number of passes over ``train_loader``.
        print_every: number of optimisation steps between two reports.
        device: device the model and the training batches are moved to.

    Raises:
        ValueError: if ``print_every`` is smaller than 1.
    """
    if print_every < 1:
        raise ValueError("print_every must be a positive integer")

    steps = 0
    # The step counter runs across epochs, so the loss accumulator must too:
    # resetting it at every epoch start would make the reported average too
    # small whenever a reporting window spans an epoch boundary.
    running_loss = 0
    model.to(device)

    for e in range(epochs):
        model.train()
        for images, labels in train_loader:
            steps += 1
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if steps % print_every == 0:
                # Switch to eval mode for the validation pass, then back.
                model.eval()
                with torch.no_grad():
                    validation_loss, accuracy = validation(model, validate_loader, criterion)
                print("Epoch: {}/{}.. ".format(e+1, epochs),
                        "Training Loss: {:.3f}.. ".format(running_loss/print_every),
                        "Validation Loss: {:.3f}.. ".format(validation_loss/len(validate_loader)),
                        "Validation Accuracy: {:.3f}".format(accuracy/len(validate_loader)))
                running_loss = 0
                model.train()
                    
# Run the training loop.
train_classifier()

In [None]:
# Build the test dataset and its loader.
test_dir = data_dir + 'test'
# Same deterministic preprocessing as used for validation.
testing_transforms = transforms.Compose([transforms.Resize(256),
                                         transforms.CenterCrop(224),
                                         transforms.ToTensor(),
                                         transforms.Normalize([0.485, 0.456, 0.406], 
                                                              [0.229, 0.224, 0.225])])

testing_dataset = datasets.ImageFolder(test_dir, transform=testing_transforms)
classes = train_loader.dataset.classes
# Take the file names from the dataset itself instead of os.listdir():
# ImageFolder sorts and filters the files it loads, while os.listdir() returns
# them in arbitrary order, so a separate listing does not line up with the
# order in which the loader yields the images.
images = [os.path.basename(path) for path, _ in testing_dataset.samples]
# No shuffling: the prediction order must match the order of `images`.
test_loader = torch.utils.data.DataLoader(testing_dataset, batch_size=1, shuffle=False)

In [None]:
# Create the submission table: one column per class, one row per test image.
# The row count is taken from the test dataset rather than hard-coded.
submission = pd.DataFrame(columns=classes, index=range(len(testing_dataset)))
submission.head()

Unnamed: 0,Actiniaria,Actinoptilum_molle,Actinoscyphia_plebeia,Actinostola_capensis,Aequorea_spp,Africolaria_rutila,Alcyonacea,Amalda_bullioides,Anthoptilum_grandiflorum,Aphelodoris_sp_,Aphrodita_alta,Aristeus_varidens,Armina_sp_,Ascidiacea,Astropecten_irregularis_pontoporeus,Athleta_abyssicola,Athleta_lutosa,Bolocera_kerguelensis,Brissopsis_lyrifera_capensis,Bryozoa,Cavernularia_spp,Cephalodiscus_gilchristi,Ceramaster_patagonicus_euryplax,Charonia_lampas,Cheilostomatida,Cheiraster_hirsutus,Chondraster_elattosis,Chrysaora_fulgida,Chrysaora_spp,Comanthus_wahlbergii,Comitas_saldanhae,Comitas_stolida,Cosmasterias_felipes,Crossaster_penicillatus,Cypraeovula_iutsui,Diplopteraster_multipes,Dipsacaster_sladeni_capensis,Echinus_gilchristi,Eleutherobia_variable,Euspira_napus,...,Philinopsis_capensis,Phormosoma_placenta_africana,Plesionika_martia,Pleurobranchaea_bubala,Polychaete_tubes_(only),Polychaete_worms,Polyechinus_agulhensis,Poraniopsis_echinaster,Porifera,Prawns,Projasus_parkeri,Pseudarchaster_brachyactis,Pseudarchaster_tessellatus,Pseudodromia_rotunda,Pseudodromia_spp_,Pseudostichopus_langeae,Psilaster_acuminatus,Pteraster_capensis,Pterygosquilla_capensis,Pycnogonid_spp_,Pyromaia_tuberculata,Rochinia_hertwigi,Rossella_antarctica,Salpa_spp_,Scaphander_punctostriatus,Scleractinia,Sclerasterias_spp,Seafan,Solenocera_africana,Spatangus_capensis,Stereomastis_sculpta,Stylasteridae,Suberites_dandelenae,Sympagurus_dimorphus,Synallactes_viridilimus,Terebratulina_sp_,Toraster_tuberculatus,Triviella_spp_,Turritella_declivis,Vitjazmaia_latidactyla
0,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
1,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
2,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
3,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
4,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,


In [None]:
#predict function
def predict(model, test_loader, submission_frame=None):
    """Write class probabilities for every test sample into a table.

    The model is switched to eval mode and moved to the GPU when one is
    available (otherwise it runs on the CPU). Row ``r`` of the table receives
    the softmax probabilities of the ``r``-th sample yielded by
    ``test_loader``, counting across batches, so any batch size is supported.

    Args:
        model: trained network producing one score per class.
        test_loader: iterable yielding ``(images, labels)`` batches; the
            labels are ignored.
        submission_frame: table to fill, with one column per model output.
            Defaults to the notebook-level ``submission`` frame.

    Returns:
        The filled table (the same object that was written to).
    """
    if submission_frame is None:
        submission_frame = submission

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.eval()
    model.to(device)

    # Explicit row counter: the loop previously indexed with a name that was
    # never defined in this function.
    row = 0
    with torch.no_grad():
        for images, _labels in test_loader:
            logits = model(images.to(device))
            # Normalise each sample's scores over the class dimension.
            probabilities = torch.softmax(logits, dim=1).cpu().numpy()
            for sample_probabilities in probabilities:
                submission_frame.loc[row] = sample_probabilities
                row += 1
    return submission_frame
        
# Run inference over the test loader; predict() writes its results into `submission`.
predict(model, test_loader)

In [None]:
# Preview the first rows of the submission table.
submission.head()

Unnamed: 0,Actiniaria,Actinoptilum_molle,Actinoscyphia_plebeia,Actinostola_capensis,Aequorea_spp,Africolaria_rutila,Alcyonacea,Amalda_bullioides,Anthoptilum_grandiflorum,Aphelodoris_sp_,Aphrodita_alta,Aristeus_varidens,Armina_sp_,Ascidiacea,Astropecten_irregularis_pontoporeus,Athleta_abyssicola,Athleta_lutosa,Bolocera_kerguelensis,Brissopsis_lyrifera_capensis,Bryozoa,Cavernularia_spp,Cephalodiscus_gilchristi,Ceramaster_patagonicus_euryplax,Charonia_lampas,Cheilostomatida,Cheiraster_hirsutus,Chondraster_elattosis,Chrysaora_fulgida,Chrysaora_spp,Comanthus_wahlbergii,Comitas_saldanhae,Comitas_stolida,Cosmasterias_felipes,Crossaster_penicillatus,Cypraeovula_iutsui,Diplopteraster_multipes,Dipsacaster_sladeni_capensis,Echinus_gilchristi,Eleutherobia_variable,Euspira_napus,...,Philinopsis_capensis,Phormosoma_placenta_africana,Plesionika_martia,Pleurobranchaea_bubala,Polychaete_tubes_(only),Polychaete_worms,Polyechinus_agulhensis,Poraniopsis_echinaster,Porifera,Prawns,Projasus_parkeri,Pseudarchaster_brachyactis,Pseudarchaster_tessellatus,Pseudodromia_rotunda,Pseudodromia_spp_,Pseudostichopus_langeae,Psilaster_acuminatus,Pteraster_capensis,Pterygosquilla_capensis,Pycnogonid_spp_,Pyromaia_tuberculata,Rochinia_hertwigi,Rossella_antarctica,Salpa_spp_,Scaphander_punctostriatus,Scleractinia,Sclerasterias_spp,Seafan,Solenocera_africana,Spatangus_capensis,Stereomastis_sculpta,Stylasteridae,Suberites_dandelenae,Sympagurus_dimorphus,Synallactes_viridilimus,Terebratulina_sp_,Toraster_tuberculatus,Triviella_spp_,Turritella_declivis,Vitjazmaia_latidactyla
0,0.00761806,8.3677e-06,0.000171063,0.000615527,2.44469e-07,8.61868e-06,1.36489e-05,0.241216,6.05073e-05,5.09163e-05,0.000142639,0.000513216,0.000506831,4.77001e-06,7.14456e-06,0.000582657,0.00139735,0.00273209,0.000291134,6.1455e-07,5.94618e-06,1.54401e-07,1.06456e-05,1.78973e-07,8.22154e-07,2.51389e-07,0.000365128,3.44383e-05,2.79416e-05,4.22926e-08,0.00194865,4.43267e-05,4.5209e-07,9.6787e-07,0.000105386,2.53453e-07,9.05108e-09,3.25574e-06,1.17827e-06,0.000531298,...,1.97071e-06,2.0345e-07,1.56393e-05,0.00171841,5.74952e-06,1.22157e-05,2.5413e-08,3.59333e-05,1.95123e-05,1.95637e-06,1.11408e-06,0.000500309,4.1037e-06,0.000611128,1.67444e-05,2.8471e-05,1.4655e-07,3.19632e-05,1.72298e-06,2.37014e-05,0.000161906,3.76347e-05,4.5518e-05,0.00194698,0.0126144,0.000336538,1.43866e-07,2.63267e-05,2.98961e-05,5.60845e-07,2.7133e-07,1.28589e-05,1.07566e-05,3.79129e-06,2.32312e-09,0.000162822,0.000172442,0.0640926,0.00068517,1.35106e-06
1,1.4058e-05,5.13251e-07,1.23506e-05,1.19531e-06,1.10138e-06,1.97224e-06,5.73024e-08,1.86944e-08,0.000373773,2.34047e-10,1.13994e-05,4.0486e-05,2.77715e-07,1.88722e-06,6.64554e-06,7.83142e-08,1.1356e-08,9.86968e-06,1.40541e-07,0.00445235,0.000116777,6.66136e-09,7.96079e-06,4.01169e-09,0.0026496,1.96014e-08,1.68258e-07,7.85347e-06,4.74691e-06,0.00204836,1.61314e-05,3.61964e-08,1.64504e-08,3.14698e-07,2.4406e-09,3.56383e-07,2.86102e-09,7.64114e-07,1.00009e-08,3.17633e-10,...,1.63029e-09,3.48006e-08,8.63779e-07,3.86229e-08,0.000136742,2.61252e-06,1.47654e-07,3.57297e-05,0.0365339,9.8849e-07,1.10245e-06,2.59418e-07,1.10851e-06,1.54951e-09,2.5056e-09,6.80726e-07,6.78123e-07,3.38555e-06,5.93362e-07,0.00171197,1.55726e-05,0.0010876,1.51324e-07,1.03768e-05,7.25239e-07,0.00199825,2.2416e-06,0.00152429,2.79971e-07,8.44678e-08,4.15018e-09,0.939542,6.4544e-06,6.53344e-09,4.86768e-09,9.21033e-08,4.86093e-06,1.65956e-08,5.87372e-09,7.67211e-07
2,1.45311e-06,9.51664e-05,2.35913e-09,3.23585e-09,8.70385e-07,2.68635e-08,3.26786e-05,3.51983e-07,0.000312424,1.11507e-10,3.65743e-05,0.000125327,7.29114e-07,1.30812e-05,6.63532e-05,1.20965e-05,1.27794e-07,4.58657e-07,9.75328e-07,3.19429e-05,2.61442e-07,5.07225e-06,2.49127e-07,4.39189e-08,6.38772e-05,4.98744e-07,3.7598e-07,9.58327e-08,4.06329e-09,0.0220202,7.42156e-05,2.7405e-05,4.67956e-06,3.87642e-05,3.89474e-07,1.41768e-08,2.83357e-08,8.64256e-06,4.01327e-08,4.06403e-07,...,5.41044e-08,5.78608e-08,4.09534e-05,2.07728e-08,0.310995,0.00112211,1.0752e-06,0.000294207,0.000738703,9.45737e-06,0.00523025,3.45251e-07,2.11549e-07,3.71439e-05,1.54845e-07,7.96896e-07,6.62922e-08,4.69749e-06,1.67304e-06,0.045054,0.536115,0.0444247,1.54473e-06,8.7294e-06,1.79772e-07,0.000287147,0.000272196,0.000152589,0.000192455,1.34849e-06,1.3067e-07,1.18087e-05,1.17551e-08,8.48726e-05,8.13392e-10,3.95563e-07,3.83543e-06,3.33152e-06,8.77016e-05,0.000108379
3,2.84563e-06,6.71574e-07,1.95635e-05,2.42628e-05,9.57552e-08,8.38095e-06,0.000291813,4.62461e-08,0.000152591,3.32929e-09,0.000562663,3.33297e-06,1.93235e-07,1.00176e-05,4.64972e-06,6.17098e-07,8.538e-08,1.46652e-05,2.147e-06,2.62249e-06,5.83823e-06,1.88949e-06,8.79257e-09,5.58141e-08,2.48307e-06,2.62314e-07,1.08998e-06,1.0194e-05,1.31185e-06,4.11055e-07,6.63586e-06,5.08383e-08,4.26245e-09,2.66654e-06,4.5372e-08,0.000655102,1.13824e-09,0.000675804,2.95295e-08,1.49119e-07,...,5.21274e-08,1.53139e-06,3.79451e-08,6.69608e-06,5.21997e-05,4.99692e-08,3.79499e-06,0.00961626,0.139589,1.24107e-08,9.10349e-07,4.26342e-07,9.2548e-08,1.91016e-06,3.51232e-06,2.12902e-07,5.87689e-11,0.000236134,3.99698e-07,1.48598e-07,9.3558e-05,0.000185362,0.0119776,0.000137501,9.83909e-07,0.000125342,1.27861e-06,1.80276e-06,4.23025e-06,4.82954e-07,6.40119e-06,4.75907e-06,0.000274857,8.35424e-09,4.2149e-07,7.26237e-06,1.08578e-07,4.26289e-07,6.514e-07,5.47633e-06
4,6.99263e-06,3.3917e-09,3.36919e-06,5.87307e-07,2.89714e-08,1.36581e-05,0.00168633,1.67622e-08,3.3368e-05,3.75249e-08,4.98332e-05,4.41717e-06,2.60435e-07,2.26331e-05,1.84161e-07,1.5689e-07,4.52515e-07,1.02057e-05,3.57822e-06,3.46873e-07,3.44328e-07,3.85115e-09,8.67046e-09,2.68914e-08,5.49964e-06,1.61481e-08,3.30167e-05,5.99859e-07,5.74086e-07,1.26375e-08,8.66894e-05,5.34622e-09,4.47039e-08,1.47488e-07,1.41167e-07,8.62617e-08,1.03846e-09,1.6295e-07,5.30383e-09,6.49073e-08,...,1.1806e-07,9.40867e-08,1.28493e-07,1.45007e-06,8.03546e-06,4.84525e-08,3.09118e-09,5.65584e-07,0.00585125,1.16836e-07,2.7632e-07,5.36847e-07,3.7205e-08,3.00387e-06,4.06968e-07,3.76163e-07,2.58612e-10,8.36891e-06,9.26248e-08,4.21876e-06,4.01973e-06,2.98855e-06,0.00179502,0.000134762,1.86483e-06,1.7256e-05,5.72346e-08,6.18244e-07,9.1568e-08,6.66972e-07,3.68035e-09,6.44366e-07,0.989982,1.65416e-08,6.14605e-09,2.39001e-07,9.80095e-06,2.43915e-08,3.86346e-06,1.53687e-05


In [None]:
# Read the sample submission to obtain the expected column layout.
sub = pd.read_csv(data_dir + 'SampleSubmission.csv')
# Keep every column name except the first one.
k = sub.columns.tolist()[1:]
#set right order of classes
submission = submission.loc[:, k]
# Put the file names in front as the "FILE" column.
submission.insert(loc=0, column="FILE", value=images)
# Write the CSV without the index and download it from Colab.
submission.to_csv('sub.csv', index=False)
files.download("sub.csv")

           FILE Pteraster_capensis  ... Triviella_spp_ Cheilostomatida
0  0RJNKE6.jpeg        3.19632e-05  ...      0.0640926     8.22154e-07
1  0RYI2RB.jpeg        3.38555e-06  ...    1.65956e-08       0.0026496
2  0S9QV3D.jpeg        4.69749e-06  ...    3.33152e-06     6.38772e-05
3  0S89A5B.jpeg        0.000236134  ...    4.26289e-07     2.48307e-06
4  0RASRVM.jpeg        8.36891e-06  ...    2.43915e-08     5.49964e-06

[5 rows x 138 columns]


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>