In [None]:
!pip install tez
!pip install efficientnet-pytorch

## Importing necessary modules

In [None]:
import os

import albumentations
import numpy as np
import pandas as pd
import tez
import torch
import torch.nn as nn
from efficientnet_pytorch import EfficientNet
from sklearn import metrics, model_selection, preprocessing
from tez.callbacks import EarlyStopping
from tez.datasets import ImageDataset
from torch.nn import functional as F

## Tez Model

In [None]:
# Sanity check: probe the channel width of the globally pooled EfficientNet-B4
# feature map, i.e. the input width the linear classifier head has to accept.
model_check = EfficientNet.from_pretrained('efficientnet-b4')
model_check.eval()
img = torch.randn(10, 3, 240, 240)
with torch.no_grad():  # shape probe only; no autograd graph is needed
    x = model_check.extract_features(img)
    x = F.adaptive_avg_pool2d(x, 1)
print(x.shape)  # torch.Size([10, 1792, 1, 1]) -> (batch, channels, 1, 1)

In [None]:
class LeafModel(tez.Model):
    """EfficientNet-B4 image classifier packaged as a ``tez.Model``.

    Backbone features are globally average-pooled, passed through dropout and
    projected to ``num_classes`` logits by a single linear layer.
    """

    def __init__(self, num_classes):
        """Create the pretrained backbone and the classification head.

        Args:
            num_classes: Number of logits produced by the linear head.
        """
        super().__init__()

        self.effnet = EfficientNet.from_pretrained("efficientnet-b4")
        self.dropout = nn.Dropout(0.1)
        # 1792 is the width of the pooled backbone features fed to the head.
        self.out = nn.Linear(1792, num_classes)
        # Read by tez; presumably makes the scheduler step once per epoch.
        self.step_scheduler_after = "epoch"

    def monitor_metrics(self, outputs, targets):
        """Compute accuracy of the arg-max predictions against ``targets``.

        Returns:
            ``{"accuracy": value}``, or an empty dict when ``targets`` is None.
        """
        if targets is None:
            return {}
        predicted = outputs.argmax(dim=1).detach().cpu().numpy()
        expected = targets.detach().cpu().numpy()
        return {"accuracy": metrics.accuracy_score(expected, predicted)}

    def fetch_optimizer(self):
        """Return an Adam optimizer (lr=3e-4) over all model parameters."""
        return torch.optim.Adam(self.parameters(), lr=3e-4)

    def fetch_scheduler(self):
        """Return a cosine-annealing warm-restart schedule bound to ``self.optimizer``."""
        schedule_args = dict(T_0=10, T_mult=1, eta_min=1e-6, last_epoch=-1)
        return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer, **schedule_args
        )

    def forward(self, image, targets=None):
        """Run the network on a 4-D image batch.

        Args:
            image: Batch tensor; its shape is unpacked into four dimensions.
            targets: Optional class indices. When given, loss and metrics are
                computed as well.

        Returns:
            ``(outputs, loss, metrics)``; ``loss`` and ``metrics`` are ``None``
            when no targets are supplied.
        """
        batch_size, _channels, _height, _width = image.shape

        features = self.effnet.extract_features(image)
        pooled = F.adaptive_avg_pool2d(features, 1).reshape(batch_size, -1)
        outputs = self.out(self.dropout(pooled))

        if targets is None:
            return outputs, None, None

        loss = F.cross_entropy(outputs, targets)
        batch_metrics = self.monitor_metrics(outputs, targets)
        return outputs, loss, batch_metrics

# Data Augmentation

In [None]:
# Side length (pixels) of the square images produced by both pipelines.
_IMG_SIZE = 256


def _normalize_transform():
    """Return a fresh Normalize step with the mean/std used by every pipeline here.

    NOTE(review): these look like the usual ImageNet statistics expected by the
    pretrained backbone - confirm.
    """
    return albumentations.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
        max_pixel_value=255.0,
        p=1.0,
    )


# Training pipeline: random crop/flip/colour jitter, normalisation, then two
# dropout-style occlusion transforms.
_train_transforms = [
    albumentations.RandomResizedCrop(_IMG_SIZE, _IMG_SIZE),
    albumentations.Transpose(p=0.5),
    albumentations.HorizontalFlip(p=0.5),
    albumentations.VerticalFlip(p=0.5),
    albumentations.ShiftScaleRotate(p=0.5),
    # NOTE(review): albumentations documents these limits in integer HSV units;
    # 0.2 may be close to a no-op - confirm the intended strength.
    albumentations.HueSaturationValue(
        hue_shift_limit=0.2,
        sat_shift_limit=0.2,
        val_shift_limit=0.2,
        p=0.5,
    ),
    albumentations.RandomBrightnessContrast(
        brightness_limit=(-0.1, 0.1),
        contrast_limit=(-0.1, 0.1),
        p=0.5,
    ),
    _normalize_transform(),
    albumentations.CoarseDropout(p=0.5),
    albumentations.Cutout(p=0.5),
]
train_aug = albumentations.Compose(_train_transforms, p=1.)

# Validation pipeline: deterministic centre crop and normalisation only.
_valid_transforms = [
    albumentations.CenterCrop(_IMG_SIZE, _IMG_SIZE, p=1.),
    # Resizes to the same size the crop above already produced.
    albumentations.Resize(_IMG_SIZE, _IMG_SIZE),
    _normalize_transform(),
]
valid_aug = albumentations.Compose(_valid_transforms, p=1.)

## Build the train and validation splits from the CSV file

In [None]:
# Load the training labels and hold out 10% of the rows for validation.
df = pd.read_csv("../input/cassava-leaf-disease-classification/train.csv")

# Stratify on the label column; the fixed seed keeps the split repeatable.
df_train, df_valid = model_selection.train_test_split(
    df,
    test_size=0.1,
    random_state=42,
    stratify=df["label"].values,
)

# Relative class frequencies in the training split.
df_train.label.value_counts(normalize=True)

In [None]:
# Relative class frequencies in the validation split.
df_valid.label.value_counts(normalize=True)

In [None]:
image_path = "../input/cassava-leaf-disease-classification/train_images"


def _image_files(directory, frame):
    """Return ``directory`` joined with every ``image_id`` value of ``frame``."""
    return [os.path.join(directory, file_name) for file_name in frame["image_id"].values]


# File paths and labels for each split, in the row order of the split frames.
train_imgs_paths = _image_files(image_path, df_train)
valid_imgs_paths = _image_files(image_path, df_valid)
train_targets = df_train["label"].values
valid_targets = df_valid["label"].values

# Training images get the randomised pipeline, validation the deterministic one.
train_dataset = ImageDataset(
    image_paths=train_imgs_paths, targets=train_targets, augmentations=train_aug
)
valid_dataset = ImageDataset(
    image_paths=valid_imgs_paths, targets=valid_targets, augmentations=valid_aug
)

# Load, Train & Save Model

In [None]:
# Checkpoint file handed to the early-stopping callback.
MODEL_PATH = "model.bin"

# One output logit per distinct label in the training CSV.
model = LeafModel(num_classes=df.label.nunique())

# Stop when validation loss has not improved for 3 epochs. Per tez's
# EarlyStopping callback, the model is written to `model_path` each time the
# monitored value improves, so that file holds the best epoch.
es = EarlyStopping(
    monitor="valid_loss", model_path=MODEL_PATH, patience=3, mode="min"
)
model.fit(
    train_dataset,
    valid_dataset=valid_dataset,
    train_bs=32,
    valid_bs=64,
    device="cuda",
    epochs=10,
    callbacks=[es],
    fp16=True,
)

# After fit() the in-memory weights are those of the last epoch run, which is
# not necessarily the best one. Saving them to MODEL_PATH here would overwrite
# the best checkpoint, so restore that checkpoint instead. Only when the
# callback never wrote a file do we fall back to saving the current weights.
if os.path.exists(MODEL_PATH):
    model.load(MODEL_PATH, device="cuda")
else:
    model.save(MODEL_PATH)

## Test the model

In [None]:
# The sample submission lists the test image ids together with placeholder labels.
test_dfx = pd.read_csv("../input/cassava-leaf-disease-classification/sample_submission.csv")

image_path = "../input/cassava-leaf-disease-classification/test_images/"
test_image_paths = [
    os.path.join(image_path, image_id) for image_id in test_dfx["image_id"].values
]
# fake targets: only there so the dataset has a label for every image
test_targets = test_dfx["label"].values

# Deterministic preprocessing for inference: centre crop, resize, normalise.
_test_transforms = [
    albumentations.CenterCrop(256, 256, p=1.),
    albumentations.Resize(256, 256),
    albumentations.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
        max_pixel_value=255.0,
        p=1.0,
    ),
]
test_aug = albumentations.Compose(_test_transforms, p=1.)

test_dataset = ImageDataset(
    image_paths=test_image_paths, targets=test_targets, augmentations=test_aug
)

In [None]:
# Run inference; `preds` is consumed one batch of scores at a time.
preds = model.predict(test_dataset, batch_size=32, n_jobs=-1)

# Collect every batch first and stack once, instead of re-stacking the growing
# array on each iteration.
batch_preds = list(preds)
if not batch_preds:
    raise ValueError("model.predict returned no batches; is the test dataset empty?")
final_preds = np.vstack(batch_preds)

# Predicted class = index of the highest score in each row.
final_preds = final_preds.argmax(axis=1)
test_dfx["label"] = final_preds

In [None]:
# Display the submission frame, whose `label` column now holds the predictions.
test_dfx

In [None]:
# Display the submission frame again (same output as the preceding cell).
test_dfx