In [1]:
!pip install timm # install pytorch image models
!pip install torchmetrics

Collecting timm
  Downloading timm-0.5.4-py3-none-any.whl (431 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m431.5/431.5 kB[0m [31m9.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: timm
Successfully installed timm-0.5.4
[0m

In [2]:
import torch
import pandas as pd
import torchvision.models as models
import timm
import albumentations as A
import cv2
import numpy as np
import tensorflow as tf

from torch import nn
from  torch.cuda.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from torchvision import transforms

In [3]:
class CustomModel(torch.nn.Module): 
    def __init__(self, model_backbone):
        super(CustomModel,self).__init__()
        self.model = model_backbone
        self.num_in_features = self.model.get_classifier().in_features
        print(self.num_in_features)
        self.model.classifier = nn.Sequential(
            nn.BatchNorm1d(self.num_in_features),
            nn.Linear(self.num_in_features, 512),
            nn.Dropout(0.5),
            nn.ReLU(inplace=True),
            nn.Linear(512, 100),
        )
    def forward(self,x):
        x = self.model(x)
        return x

In [4]:
class SorghumDataset(Dataset):
    def __init__(self, dirs, labels, transformation=None):
        super(SorghumDataset,self).__init__()
        self.dirs = dirs
        self.labels = labels
        self.transformation = transformation
    def __len__(self):
        return len(self.dirs)

    def __getitem__(self, index):
        image = cv2.imread(self.dirs[index])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        label = self.labels[index] # need to one hot encoding here
        
        image = np.array(image)

        if self.transformation:
            aug_image = self.transformation(image=image)
            image = aug_image['image']
            
        image = image / 255.
        image = image.transpose((2, 0, 1))
        
        image = torch.from_numpy(image).type(torch.float32)
        image = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))(image)
        
        labels = torch.from_numpy(np.array(self.labels[index])).type(torch.float32)


        return image, labels

# Resnet

In [5]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model_name = 'resnext50d_32x4d'

backbone = timm.create_model(model_name,pretrained=True)
model = CustomModel(backbone)

model.to(device)

Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnext50d_32x4d-103e99f8.pth" to /root/.cache/torch/hub/checkpoints/resnext50d_32x4d-103e99f8.pth


2048


CustomModel(
  (model): ResNet(
    (conv1): Sequential(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
      (6): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act1): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stat

In [6]:
checkpoint = torch.load('../input/sorghum-efficientnetv2-0-846-private-lb/resnext50d_32x4d_best.pt')
model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [7]:
sub = pd.read_csv('../input/sorghum-id-fgvc-9/sample_submission.csv')
sub.head()

Unnamed: 0,filename,cultivar
0,1000005362.png,PI_152923
1,1000099707.png,PI_152923
2,1000135300.png,PI_152923
3,1000136796.png,PI_152923
4,1000292439.png,PI_152923


In [8]:
sub["filename"] = sub["filename"].apply(lambda image: '../input/sorghum-id-fgvc-9/test/' + image)
sub["cultivar"] = 0
sub.head()

Unnamed: 0,filename,cultivar
0,../input/sorghum-id-fgvc-9/test/1000005362.png,0
1,../input/sorghum-id-fgvc-9/test/1000099707.png,0
2,../input/sorghum-id-fgvc-9/test/1000135300.png,0
3,../input/sorghum-id-fgvc-9/test/1000136796.png,0
4,../input/sorghum-id-fgvc-9/test/1000292439.png,0


In [9]:
validation_transformation = A.Compose([
    A.Resize(width=512, height=512, p=1.0)
])

testing_dataset = SorghumDataset(sub['filename'], sub['cultivar'], validation_transformation)
testing_dataloader = DataLoader(testing_dataset, 
                                batch_size=32, 
                                shuffle=False, 
                                num_workers=1)

In [10]:
predictions = []
cnt = 0

resnet_preds = []

with torch.no_grad():
    for image, label in tqdm(testing_dataloader):
        image = image.to(device)
        outputs = model(image)
        for i in range(len(outputs)):
            resnet_preds.append(outputs[i][:100])
#         resnet_preds.append(outputs[0][:100])
        preds = outputs.detach().cpu()
        predictions.append(preds.argmax(1))

100%|██████████| 739/739 [26:22<00:00,  2.14s/it]


In [11]:
resnet_preds_ = []

for i in range(len(resnet_preds)):
    resnet_preds_.append(resnet_preds[i].tolist())

In [12]:
resnet_preds_list = pd.DataFrame({'0_class':[]})

for i in range(99, 0, -1): 
    resnet_preds_list.insert(1, "{}_class".format(i), [])

for i in range(len(resnet_preds_)):
    resnet_preds_list.loc[i] = resnet_preds_[i]

In [13]:
resnet_preds_list.to_csv('resnet_submission.csv', index=False)

resnet_preds_list.head()

Unnamed: 0,0_class,1_class,2_class,3_class,4_class,5_class,6_class,7_class,8_class,9_class,...,90_class,91_class,92_class,93_class,94_class,95_class,96_class,97_class,98_class,99_class
0,-4.801618,-2.975569,-1.0109,0.004538,0.350488,-4.283353,-4.044071,-1.81717,-3.906837,-5.185993,...,-4.440704,-5.510637,-3.565371,-4.36757,-3.586168,-4.021678,-5.994962,-4.712767,-2.815238,-1.952368
1,-3.122581,-1.158474,-8.431793,-7.456934,-7.085815,-5.844799,-4.357646,-6.926689,-6.120798,0.864344,...,-1.683628,-4.94246,-5.104192,-3.478405,-3.778158,-4.185591,-8.658447,-5.647788,-5.539834,-7.571363
2,-5.502335,-1.471846,8.994921,-1.64459,2.780464,-4.725876,-4.252627,-5.074745,0.530009,-0.46011,...,-3.465601,-5.762954,-5.205477,-4.648644,-1.996166,-7.006385,-6.289888,-5.483791,-3.666237,-3.386745
3,-1.438631,-2.900598,-2.842964,-2.837878,-2.651607,-3.997262,-2.419579,-2.409816,-1.67164,-2.249619,...,-2.150546,-3.092925,-1.705112,-3.123335,-2.864271,-2.119378,-2.402761,-2.806396,-3.288818,-2.080209
4,-5.85615,-0.279349,-4.226741,-2.668807,-2.526486,-1.11652,-0.516609,-5.410573,-4.07262,-1.866208,...,-2.197257,-7.641562,-1.674403,-4.825066,-5.543388,-5.794561,-5.393475,-6.552042,0.802193,-5.098389


# EfficientNetB5

In [14]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model_name = 'tf_efficientnetv2_m_in21k'

backbone = timm.create_model(model_name,pretrained=True)
B5_model = CustomModel(backbone)

B5_model.to(device)

Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m_21k-361418a2.pth" to /root/.cache/torch/hub/checkpoints/tf_efficientnetv2_m_21k-361418a2.pth


1280


CustomModel(
  (model): EfficientNet(
    (conv_stem): Conv2dSame(3, 24, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (bn1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    (act1): SiLU(inplace=True)
    (blocks): Sequential(
      (0): Sequential(
        (0): ConvBnAct(
          (conv): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
        )
        (1): ConvBnAct(
          (conv): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
        )
        (2): ConvBnAct(
          (conv): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(24, eps=0.001, momentum=0.1, aff

In [15]:
checkpoint = torch.load('../input/sorghum-identification-12345/tf_efficientnetv2_m_in21k_best.pt')
B5_model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [16]:
predictions = []
cnt = 0

B5_preds = []

with torch.no_grad():
    for image, label in tqdm(testing_dataloader):
        image = image.to(device)
        outputs = model(image)
        for i in range(len(outputs)):
            B5_preds.append(outputs[i][:100])
#         resnet_preds.append(outputs[0][:100])
        preds = outputs.detach().cpu()
        predictions.append(preds.argmax(1))

100%|██████████| 739/739 [25:24<00:00,  2.06s/it]


In [17]:
B5_preds_ = []

for i in range(len(B5_preds)):
    B5_preds_.append(B5_preds[i].tolist())

In [18]:
B5_preds_list = pd.DataFrame({'0_class':[]})

for i in range(99, 0, -1): 
    B5_preds_list.insert(1, "{}_class".format(i), [])

for i in range(len(B5_preds_)):
    B5_preds_list.loc[i] = B5_preds_[i]

In [19]:
B5_preds_list

Unnamed: 0,0_class,1_class,2_class,3_class,4_class,5_class,6_class,7_class,8_class,9_class,...,90_class,91_class,92_class,93_class,94_class,95_class,96_class,97_class,98_class,99_class
0,-4.801618,-2.975569,-1.010900,0.004538,0.350488,-4.283353,-4.044071,-1.817170,-3.906837,-5.185993,...,-4.440704,-5.510637,-3.565371,-4.367570,-3.586168,-4.021678,-5.994962,-4.712767,-2.815238,-1.952368
1,-3.122581,-1.158474,-8.431793,-7.456934,-7.085815,-5.844799,-4.357646,-6.926689,-6.120798,0.864344,...,-1.683628,-4.942460,-5.104192,-3.478405,-3.778158,-4.185591,-8.658447,-5.647788,-5.539834,-7.571363
2,-5.502335,-1.471846,8.994921,-1.644590,2.780464,-4.725876,-4.252627,-5.074745,0.530009,-0.460110,...,-3.465601,-5.762954,-5.205477,-4.648644,-1.996166,-7.006385,-6.289888,-5.483791,-3.666237,-3.386745
3,-1.438631,-2.900598,-2.842964,-2.837878,-2.651607,-3.997262,-2.419579,-2.409816,-1.671640,-2.249619,...,-2.150546,-3.092925,-1.705112,-3.123335,-2.864271,-2.119378,-2.402761,-2.806396,-3.288818,-2.080209
4,-5.856150,-0.279349,-4.226741,-2.668807,-2.526486,-1.116520,-0.516609,-5.410573,-4.072620,-1.866208,...,-2.197257,-7.641562,-1.674403,-4.825066,-5.543388,-5.794561,-5.393475,-6.552042,0.802193,-5.098389
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
23634,-7.223331,-5.498235,-7.183830,-6.044033,-3.454030,-2.869962,-5.262556,-7.025408,-7.600406,-8.049458,...,-5.938259,-7.105239,-5.009425,-5.695383,-4.508875,-4.701503,-6.263385,-6.115704,-5.626700,-7.625852
23635,-4.395103,-5.223400,-2.850434,-0.915796,-1.466926,-4.560630,-1.761602,-4.885693,-3.868564,-7.511627,...,-5.109028,-6.600894,-5.732879,-3.651545,-4.017533,-4.707999,-6.996043,-2.584617,-3.975872,-3.213772
23636,-7.714730,-1.864583,-7.216725,-7.704274,-9.781572,-10.407336,-3.443444,-6.472396,-7.049388,-6.616426,...,-5.660957,-8.343438,-9.968068,-7.830461,3.343957,-7.972091,-7.800015,-5.601650,-6.300623,-10.347117
23637,-5.930204,-7.474098,-6.242562,-5.108381,-5.487892,0.793515,-6.607038,-3.528217,-6.645050,-8.464906,...,-3.553557,-5.071094,-7.368112,-6.799894,-3.683982,-5.894517,-4.045214,-7.671366,-3.703960,-5.831462


In [20]:
B5_preds_list.to_csv('B5_submission.csv', index=False)

B5_preds_list.head()

Unnamed: 0,0_class,1_class,2_class,3_class,4_class,5_class,6_class,7_class,8_class,9_class,...,90_class,91_class,92_class,93_class,94_class,95_class,96_class,97_class,98_class,99_class
0,-4.801618,-2.975569,-1.0109,0.004538,0.350488,-4.283353,-4.044071,-1.81717,-3.906837,-5.185993,...,-4.440704,-5.510637,-3.565371,-4.36757,-3.586168,-4.021678,-5.994962,-4.712767,-2.815238,-1.952368
1,-3.122581,-1.158474,-8.431793,-7.456934,-7.085815,-5.844799,-4.357646,-6.926689,-6.120798,0.864344,...,-1.683628,-4.94246,-5.104192,-3.478405,-3.778158,-4.185591,-8.658447,-5.647788,-5.539834,-7.571363
2,-5.502335,-1.471846,8.994921,-1.64459,2.780464,-4.725876,-4.252627,-5.074745,0.530009,-0.46011,...,-3.465601,-5.762954,-5.205477,-4.648644,-1.996166,-7.006385,-6.289888,-5.483791,-3.666237,-3.386745
3,-1.438631,-2.900598,-2.842964,-2.837878,-2.651607,-3.997262,-2.419579,-2.409816,-1.67164,-2.249619,...,-2.150546,-3.092925,-1.705112,-3.123335,-2.864271,-2.119378,-2.402761,-2.806396,-3.288818,-2.080209
4,-5.85615,-0.279349,-4.226741,-2.668807,-2.526486,-1.11652,-0.516609,-5.410573,-4.07262,-1.866208,...,-2.197257,-7.641562,-1.674403,-4.825066,-5.543388,-5.794561,-5.393475,-6.552042,0.802193,-5.098389


# EfficientNetB4

In [21]:
!pip install efficientnet_pytorch

Collecting efficientnet_pytorch
  Downloading efficientnet_pytorch-0.7.1.tar.gz (21 kB)
  Preparing metadata (setup.py) ... [?25l- done
Building wheels for collected packages: efficientnet_pytorch
  Building wheel for efficientnet_pytorch (setup.py) ... [?25l- \ done
[?25h  Created wheel for efficientnet_pytorch: filename=efficientnet_pytorch-0.7.1-py3-none-any.whl size=16446 sha256=1492bac251fbf034459a17f3ae8a1a794a82cf14714272f9fd1483ef463e06d4
  Stored in directory: /root/.cache/pip/wheels/0e/cc/b2/49e74588263573ff778da58cc99b9c6349b496636a7e165be6
Successfully built efficientnet_pytorch
Installing collected packages: efficientnet_pytorch
Successfully installed efficientnet_pytorch-0.7.1
[0m

In [22]:
from efficientnet_pytorch import EfficientNet

model_B4 = EfficientNet.from_name('efficientnet-b4')
model_B4.load_state_dict(torch.load("../input/test-for-kaggle-0426/epoch25.pt", map_location='cuda'))

model_B4.to(device)

EfficientNet(
  (_conv_stem): Conv2dStaticSamePadding(
    3, 48, kernel_size=(3, 3), stride=(2, 2), bias=False
    (static_padding): ZeroPad2d((0, 1, 0, 1))
  )
  (_bn0): BatchNorm2d(48, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
  (_blocks): ModuleList(
    (0): MBConvBlock(
      (_depthwise_conv): Conv2dStaticSamePadding(
        48, 48, kernel_size=(3, 3), stride=[1, 1], groups=48, bias=False
        (static_padding): ZeroPad2d((1, 1, 1, 1))
      )
      (_bn1): BatchNorm2d(48, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
      (_se_reduce): Conv2dStaticSamePadding(
        48, 12, kernel_size=(1, 1), stride=(1, 1)
        (static_padding): Identity()
      )
      (_se_expand): Conv2dStaticSamePadding(
        12, 48, kernel_size=(1, 1), stride=(1, 1)
        (static_padding): Identity()
      )
      (_project_conv): Conv2dStaticSamePadding(
        48, 24, kernel_size=(1, 1), stride=(1, 1), bias=False
  

In [23]:
predictions = []
cnt = 0

B4_preds = []

with torch.no_grad():
    for image, label in tqdm(testing_dataloader):
        image = image.to(device)
        outputs = model_B4(image)
        for i in range(len(outputs)):
            B4_preds.append(outputs[i][:100])
#         resnet_preds.append(outputs[0][:100])
        preds = outputs.detach().cpu()
        predictions.append(preds.argmax(1))

100%|██████████| 739/739 [24:56<00:00,  2.03s/it]


In [24]:
B4_preds_ = []

for i in range(len(B4_preds)):
    B4_preds_.append(B4_preds[i].tolist())

In [25]:
B4_preds_list = pd.DataFrame({'0_class':[]})

for i in range(99, 0, -1): 
    B4_preds_list.insert(1, "{}_class".format(i), [])

for i in range(len(B5_preds_)):
    B4_preds_list.loc[i] = B4_preds_[i]

In [26]:
B4_preds_list.to_csv('B4_submission.csv', index=False)

B4_preds_list.head()

Unnamed: 0,0_class,1_class,2_class,3_class,4_class,5_class,6_class,7_class,8_class,9_class,...,90_class,91_class,92_class,93_class,94_class,95_class,96_class,97_class,98_class,99_class
0,-2.571729,5.977733,9.317001,4.836597,10.742593,-1.167939,2.632571,2.478869,2.740776,3.915955,...,5.307129,0.584284,-6.052864,3.654924,4.087985,-2.358377,0.874906,3.654458,-3.120832,2.567825
1,7.153302,5.257663,-1.771159,-2.009251,3.975573,0.78169,6.442517,-1.134265,2.609756,8.876288,...,6.548178,6.517748,1.61528,5.22577,1.504519,3.653682,0.868913,0.422606,2.36736,1.117713
2,-3.17049,5.333634,16.227261,6.155414,11.360859,-4.58211,0.47357,0.87626,6.30806,3.260907,...,-2.221341,-0.27514,-4.477224,5.252504,2.67506,-4.74551,2.464423,-1.880267,-1.138426,3.315015
3,-0.998833,-0.157031,1.312434,0.691659,2.927119,2.488154,4.697551,4.778651,4.995743,3.199096,...,0.44724,-2.934232,2.130785,0.838407,-1.519703,-0.061749,-1.849497,1.312758,1.17995,2.895253
4,-2.754803,10.598209,3.455296,4.492938,4.692679,1.566526,6.785053,-1.870538,7.657426,12.105474,...,-0.51993,-2.021164,3.169623,9.916356,3.480946,-1.801913,-2.876096,0.658748,11.843788,0.946106


# Ensemble Learning (b4+b5+resnet)

In [27]:
resnet_result = pd.read_csv('../input/resnet-submission/resnet_submission.csv')

resnet_result.head()

Unnamed: 0,0_class,1_class,2_class,3_class,4_class,5_class,6_class,7_class,8_class,9_class,...,90_class,91_class,92_class,93_class,94_class,95_class,96_class,97_class,98_class,99_class
0,-4.801618,-2.975569,-1.0109,0.004538,0.350488,-4.283353,-4.044071,-1.81717,-3.906837,-5.185993,...,-4.440704,-5.510637,-3.565371,-4.36757,-3.586168,-4.021678,-5.994962,-4.712767,-2.815238,-1.952368
1,-3.122581,-1.158474,-8.431793,-7.456934,-7.085815,-5.844799,-4.357646,-6.926689,-6.120798,0.864344,...,-1.683628,-4.94246,-5.104192,-3.478405,-3.778158,-4.185591,-8.658447,-5.647788,-5.539834,-7.571363
2,-5.502335,-1.471846,8.994921,-1.64459,2.780464,-4.725876,-4.252627,-5.074745,0.530009,-0.46011,...,-3.465601,-5.762954,-5.205477,-4.648644,-1.996166,-7.006385,-6.289888,-5.483791,-3.666237,-3.386745
3,-1.438631,-2.900598,-2.842964,-2.837878,-2.651607,-3.997262,-2.419579,-2.409816,-1.67164,-2.249619,...,-2.150546,-3.092925,-1.705112,-3.123335,-2.864271,-2.119378,-2.402761,-2.806396,-3.288818,-2.080209
4,-5.85615,-0.279349,-4.226741,-2.668807,-2.526486,-1.11652,-0.516609,-5.410573,-4.07262,-1.866208,...,-2.197257,-7.641562,-1.674403,-4.825066,-5.543388,-5.794561,-5.393475,-6.552042,0.802193,-5.098389


In [28]:
b5_result = pd.read_csv('../input/b5-sub/B5_submission.csv')

b5_result = b5_result * 10

b5_result.head()

Unnamed: 0,0_class,1_class,2_class,3_class,4_class,5_class,6_class,7_class,8_class,9_class,...,90_class,91_class,92_class,93_class,94_class,95_class,96_class,97_class,98_class,99_class
0,0.892773,-3.966359,-2.954552,-1.383131,-1.230138,-1.605695,7.928654,4.240064,0.428775,0.965414,...,8.436068,1.253415,-5.170112,3.990971,3.111282,-1.808589,-2.022345,-7.563506,2.165346,-2.866596
1,1.648182,-0.16481,-2.596343,0.25676,1.757425,-3.846923,1.825878,0.021879,1.588355,-0.191997,...,3.332699,-0.81718,-1.9121,-2.092699,0.625315,-5.305366,0.989622,-1.957383,-3.33418,-3.192983
2,-0.481743,0.391381,-0.247089,2.485051,-1.057789,0.69611,-4.091997,-1.179258,-0.500628,-3.4655,...,7.796316,-1.606795,2.412722,-0.056671,2.127215,1.371859,3.259758,-6.339055,0.069074,0.115981
3,-0.715939,1.363132,-2.073458,2.271285,-1.479022,-1.946633,1.016834,-2.028146,-2.438629,-0.708824,...,3.56992,3.384302,0.336098,-3.883107,-0.452094,-3.788221,6.431556,-1.021312,1.686525,1.532885
4,-1.381458,2.214819,4.023111,4.014326,1.586189,-1.834667,-0.580804,6.224906,-0.865272,-1.505553,...,4.452855,2.541638,-0.841258,1.238704,-3.511993,1.813707,1.622686,1.142952,0.504797,-0.824462


In [29]:
b4_result = pd.read_csv('../input/b4-sub/B4_submission.csv')

# b4_result = b4_result * 10

b4_result.head()

Unnamed: 0,0_class,1_class,2_class,3_class,4_class,5_class,6_class,7_class,8_class,9_class,...,90_class,91_class,92_class,93_class,94_class,95_class,96_class,97_class,98_class,99_class
0,-1.971418,7.601085,6.253662,4.200216,11.877116,-2.441239,5.244116,-0.190037,3.397834,4.547288,...,4.986811,1.825171,-7.296851,6.224649,3.114617,-0.30857,0.266424,0.494667,-5.222482,4.137463
1,7.472752,5.776456,-2.737694,-1.720482,3.65245,0.276067,6.608448,0.864106,4.592688,8.591151,...,7.866305,6.736653,2.71758,6.264142,2.302805,2.655147,1.018003,-0.706559,2.704113,1.300578
2,-3.553531,6.922027,15.746008,5.20561,11.518566,-2.183513,0.309439,1.676199,8.241667,3.346487,...,-0.024887,-0.938581,-4.564708,3.53074,1.694313,-5.83182,1.436764,-1.031949,-1.899167,3.243559
3,-1.805562,0.220727,0.537807,-0.447724,4.249485,1.42249,3.598542,4.122486,6.112689,3.032334,...,-0.533039,-2.219996,1.833291,1.130475,-0.753881,0.365807,-2.0438,1.86066,0.673981,3.255714
4,-1.025319,9.523963,3.477371,4.063594,4.372114,2.653487,5.072325,-1.769767,4.885382,8.662114,...,-0.438711,-0.779596,3.789347,6.610578,3.096414,0.834889,-1.003939,3.763621,12.943184,1.951309


In [30]:
b5_result['max_value'] = b5_result.max(axis=1)
b5_result['class'] = b5_result.idxmax(axis=1)

resnet_result['max_value'] = resnet_result.max(axis=1)
resnet_result['class'] = resnet_result.idxmax(axis=1)

b4_result['max_value'] = b4_result.max(axis=1)
b4_result['class'] = b4_result.idxmax(axis=1)

In [31]:
b5_result.head()

Unnamed: 0,0_class,1_class,2_class,3_class,4_class,5_class,6_class,7_class,8_class,9_class,...,92_class,93_class,94_class,95_class,96_class,97_class,98_class,99_class,max_value,class
0,0.892773,-3.966359,-2.954552,-1.383131,-1.230138,-1.605695,7.928654,4.240064,0.428775,0.965414,...,-5.170112,3.990971,3.111282,-1.808589,-2.022345,-7.563506,2.165346,-2.866596,8.436068,90_class
1,1.648182,-0.16481,-2.596343,0.25676,1.757425,-3.846923,1.825878,0.021879,1.588355,-0.191997,...,-1.9121,-2.092699,0.625315,-5.305366,0.989622,-1.957383,-3.33418,-3.192983,8.935465,29_class
2,-0.481743,0.391381,-0.247089,2.485051,-1.057789,0.69611,-4.091997,-1.179258,-0.500628,-3.4655,...,2.412722,-0.056671,2.127215,1.371859,3.259758,-6.339055,0.069074,0.115981,7.796316,90_class
3,-0.715939,1.363132,-2.073458,2.271285,-1.479022,-1.946633,1.016834,-2.028146,-2.438629,-0.708824,...,0.336098,-3.883107,-0.452094,-3.788221,6.431556,-1.021312,1.686525,1.532885,7.316633,29_class
4,-1.381458,2.214819,4.023111,4.014326,1.586189,-1.834667,-0.580804,6.224906,-0.865272,-1.505553,...,-0.841258,1.238704,-3.511993,1.813707,1.622686,1.142952,0.504797,-0.824462,8.706151,60_class


In [32]:
resnet_result.head()

Unnamed: 0,0_class,1_class,2_class,3_class,4_class,5_class,6_class,7_class,8_class,9_class,...,92_class,93_class,94_class,95_class,96_class,97_class,98_class,99_class,max_value,class
0,-4.801618,-2.975569,-1.0109,0.004538,0.350488,-4.283353,-4.044071,-1.81717,-3.906837,-5.185993,...,-3.565371,-4.36757,-3.586168,-4.021678,-5.994962,-4.712767,-2.815238,-1.952368,5.471979,86_class
1,-3.122581,-1.158474,-8.431793,-7.456934,-7.085815,-5.844799,-4.357646,-6.926689,-6.120798,0.864344,...,-5.104192,-3.478405,-3.778158,-4.185591,-8.658447,-5.647788,-5.539834,-7.571363,19.686892,55_class
2,-5.502335,-1.471846,8.994921,-1.64459,2.780464,-4.725876,-4.252627,-5.074745,0.530009,-0.46011,...,-5.205477,-4.648644,-1.996166,-7.006385,-6.289888,-5.483791,-3.666237,-3.386745,8.994921,2_class
3,-1.438631,-2.900598,-2.842964,-2.837878,-2.651607,-3.997262,-2.419579,-2.409816,-1.67164,-2.249619,...,-1.705112,-3.123335,-2.864271,-2.119378,-2.402761,-2.806396,-3.288818,-2.080209,13.621403,48_class
4,-5.85615,-0.279349,-4.226741,-2.668807,-2.526486,-1.11652,-0.516609,-5.410573,-4.07262,-1.866208,...,-1.674403,-4.825066,-5.543388,-5.794561,-5.393475,-6.552042,0.802193,-5.098389,4.296759,13_class


In [33]:
b4_result.head()

Unnamed: 0,0_class,1_class,2_class,3_class,4_class,5_class,6_class,7_class,8_class,9_class,...,92_class,93_class,94_class,95_class,96_class,97_class,98_class,99_class,max_value,class
0,-1.971418,7.601085,6.253662,4.200216,11.877116,-2.441239,5.244116,-0.190037,3.397834,4.547288,...,-7.296851,6.224649,3.114617,-0.30857,0.266424,0.494667,-5.222482,4.137463,14.639492,39_class
1,7.472752,5.776456,-2.737694,-1.720482,3.65245,0.276067,6.608448,0.864106,4.592688,8.591151,...,2.71758,6.264142,2.302805,2.655147,1.018003,-0.706559,2.704113,1.300578,20.06641,55_class
2,-3.553531,6.922027,15.746008,5.20561,11.518566,-2.183513,0.309439,1.676199,8.241667,3.346487,...,-4.564708,3.53074,1.694313,-5.83182,1.436764,-1.031949,-1.899167,3.243559,15.746008,2_class
3,-1.805562,0.220727,0.537807,-0.447724,4.249485,1.42249,3.598542,4.122486,6.112689,3.032334,...,1.833291,1.130475,-0.753881,0.365807,-2.0438,1.86066,0.673981,3.255714,19.067812,48_class
4,-1.025319,9.523963,3.477371,4.063594,4.372114,2.653487,5.072325,-1.769767,4.885382,8.662114,...,3.789347,6.610578,3.096414,0.834889,-1.003939,3.763621,12.943184,1.951309,12.943184,98_class


In [34]:
result = []

for i in range(len(b5_result)):
    max_val = max((b5_result['max_value'][i] * 0.3),(resnet_result['max_value'][i] * 1.7),(b4_result['max_value'][i] * 0.5))
    
    if max_val == (b5_result['max_value'][i] * 0.3):
        result.append((max_val, b5_result['class'][i]))
        
    elif max_val == (resnet_result['max_value'][i] * 1.7):
        result.append((max_val, resnet_result['class'][i]))
        
    else:
        result.append((max_val, b4_result['class'][i]))

In [35]:
result[:5]

[(9.302363729476928, '86_class'),
 (33.467715644836424, '55_class'),
 (15.291365242004394, '2_class'),
 (23.156384658813476, '48_class'),
 (7.3044905185699465, '13_class')]

In [36]:
import re
 
tmp = []

for i in range(len(result)):
    tmp.append(int(re.sub(r'[^0-9]', '', result[i][1])))
    
print(tmp)

[86, 55, 2, 48, 13, 63, 79, 57, 43, 77, 50, 10, 0, 40, 10, 96, 32, 64, 73, 54, 74, 63, 2, 98, 42, 18, 2, 43, 38, 6, 67, 88, 90, 88, 83, 60, 10, 0, 12, 74, 2, 65, 32, 72, 83, 33, 55, 87, 14, 7, 55, 12, 23, 49, 72, 70, 68, 71, 4, 66, 82, 26, 20, 47, 80, 93, 70, 51, 94, 84, 29, 10, 28, 57, 25, 36, 24, 59, 43, 79, 4, 44, 77, 86, 26, 77, 73, 8, 68, 30, 52, 39, 77, 17, 0, 79, 81, 57, 0, 35, 9, 34, 19, 57, 54, 32, 72, 80, 42, 90, 75, 57, 82, 2, 58, 30, 2, 29, 28, 13, 98, 15, 23, 53, 18, 62, 22, 11, 88, 18, 40, 6, 85, 37, 79, 44, 42, 81, 45, 23, 67, 25, 29, 37, 98, 66, 44, 68, 3, 79, 24, 40, 3, 91, 12, 93, 43, 22, 50, 55, 79, 53, 27, 64, 59, 91, 80, 28, 74, 42, 3, 10, 40, 17, 30, 88, 72, 16, 75, 47, 57, 78, 89, 22, 57, 10, 68, 84, 37, 98, 2, 49, 70, 67, 23, 39, 71, 74, 91, 89, 60, 97, 97, 26, 0, 58, 37, 13, 21, 28, 33, 55, 15, 61, 24, 37, 56, 65, 53, 57, 47, 81, 6, 18, 28, 25, 47, 23, 36, 16, 46, 8, 79, 58, 11, 70, 85, 63, 46, 91, 6, 70, 80, 41, 9, 49, 64, 23, 0, 12, 37, 71, 11, 19, 18, 67, 20

In [37]:
df_all = pd.read_csv('../input/sorghum-id-fgvc-9/train_cultivar_mapping.csv')
df_all.dropna(inplace=True)

unique_cultivars = list(df_all["cultivar"].unique())

predictions = [unique_cultivars[pred] for pred in tmp]

In [38]:
sub = pd.read_csv('../input/sorghum-id-fgvc-9/sample_submission.csv')
sub['cultivar'] = predictions
sub.to_csv('submission5.csv', index=False)
sub.head()

Unnamed: 0,filename,cultivar
0,1000005362.png,PI_218112
1,1000099707.png,PI_329333
2,1000135300.png,PI_92270
3,1000136796.png,PI_329256
4,1000292439.png,PI_156393
