In [1]:
# Tez is currently not available on kaggle but you can install it using pip 
# or just add tez-lib dataset to the python path.
# If internet is enabled (and allowed), you can just install using pip
!pip install tez

Collecting tez
  Downloading tez-0.6.3-py3-none-any.whl (19 kB)
Installing collected packages: tez
Successfully installed tez-0.6.3


In [2]:
!pip install timm

Collecting timm
  Downloading timm-0.5.4-py3-none-any.whl (431 kB)
     |████████████████████████████████| 431 kB 559 kB/s            
Installing collected packages: timm
Successfully installed timm-0.5.4


In [3]:
# Everything becomes easy and intuitive from here. 
# Also, Tez keeps your code clean and readable!
# Let's import a few things.

import glob
import os
import albumentations
import timm
import torch
import torch.nn as nn
from sklearn import metrics, preprocessing, model_selection

from tez import Tez, TezConfig
from tez.callbacks import EarlyStopping
from tez.datasets import ImageDataset


import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

%matplotlib inline

In [4]:

# --- Paths -----------------------------------------------------------------
# Root of the attached dataset; the resized images live a few levels below it.
INPUT_PATH = "../input/instacities1m/"
IMAGE_PATH = INPUT_PATH + "InstaCities1M/img_resized_1M/cities_instagram"
# Directory and base name used when building the checkpoint file path.
MODEL_PATH = "../working/"
MODEL_NAME = "resnet18"

# --- Run settings ----------------------------------------------------------
TRAIN_BATCH_SIZE = VALID_BATCH_SIZE = 32
EPOCHS = 20
IMAGE_SIZE = 300

In [5]:
# Let's define a model now.
# With Tez 0.6 the model is a plain nn.Module (it is wrapped in `Tez(...)` later)
# rather than a subclass of tez.Model.
# monitor_metrics computes any metrics we want to monitor besides the loss,
# and forward returns 3 values: outputs, loss and metrics.

class InstaModel(nn.Module):
    """ResNet-18 classifier with a replaceable classification head.

    The backbone is created through timm with pretrained weights and its final
    fully connected layer is swapped for a fresh ``num_classes``-way linear
    layer. ``forward`` returns an ``(outputs, loss, metrics)`` triple.

    Args:
        num_classes: Number of output classes of the new linear head.
    """

    def __init__(self, num_classes):
        super().__init__()

        self.model = timm.create_model("resnet18", pretrained=True)
        # Replace the stock head with one sized for our label set.
        n_features = self.model.fc.in_features
        self.model.fc = nn.Linear(n_features, num_classes)

    def monitor_metrics(self, outputs, targets):
        """Compute the batch accuracy.

        Args:
            outputs: Tensor of per-class scores; the class axis is dim 1.
            targets: Tensor of integer class ids, or ``None``.

        Returns:
            ``{}`` when ``targets`` is ``None``; otherwise a dict with key
            ``"accuracy"`` holding a scalar float32 tensor that lives on the
            same device as ``targets``.
        """
        # The None check has to come before any attribute access on `targets`.
        if targets is None:
            return {}
        # Stay in torch: this avoids a host round-trip and works on CPU as well
        # as GPU (Tensor.get_device() returns -1 for CPU tensors, which is not
        # a valid `device=` argument).
        predicted = torch.argmax(outputs, dim=1).to(targets.device)
        accuracy = (predicted == targets).float().mean()
        return {"accuracy": accuracy.detach()}

    def optimizer_scheduler(self):
        """Build the optimizer and LR scheduler.

        Returns:
            ``(opt, sch)``: an Adam optimizer over all parameters (lr 1e-3) and
            a ``ReduceLROnPlateau`` scheduler in ``"max"`` mode that halves the
            learning rate after 2 epochs without improvement.
        """
        opt = torch.optim.Adam(self.parameters(), lr=1e-3)
        sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
            opt,
            factor=0.5,
            patience=2,
            verbose=True,
            mode="max",
            threshold=1e-4,
        )
        return opt, sch

    def forward(self, image, targets=None):
        """Run the backbone and, when targets are given, the loss and metrics.

        Args:
            image: Input batch passed straight to the backbone.
            targets: Optional tensor of integer class ids.

        Returns:
            ``(outputs, loss, metrics)``. ``loss`` is the cross-entropy loss and
            ``metrics`` the dict from ``monitor_metrics``; both are ``None``
            when ``targets`` is ``None``.
        """
        outputs = self.model(image)

        if targets is None:
            return outputs, None, None

        loss = nn.CrossEntropyLoss()(outputs, targets)
        # Named `batch_metrics` so it does not shadow the module-level
        # `metrics` import.
        batch_metrics = self.monitor_metrics(outputs, targets)
        return outputs, loss, batch_metrics

In [6]:
# --- Load the test split ------------------------------------------------------
dfx = pd.read_csv(os.path.join(INPUT_PATH, "test.csv"))
dfx = dfx.dropna().reset_index(drop=True)
# Relative image path is "<category>/<id>.jpg". It must be built while
# `category` still holds its original value, i.e. before label encoding below.
dfx["path"] = dfx["category"].astype(str) + "/" + dfx["id"].astype(str) + ".jpg"

# NOTE(review): the encoder is fitted on the test labels themselves. This only
# matches the checkpoint if training produced the same label -> id mapping
# (same set of categories) — confirm against the training notebook.
lbl_enc = preprocessing.LabelEncoder()
dfx["category"] = lbl_enc.fit_transform(dfx["category"].values)

test_image_paths = [os.path.join(IMAGE_PATH, rel_path) for rel_path in dfx["path"].values]
test_targets = dfx["category"].values

# --- Dataset --------------------------------------------------------------------
# NOTE(review): the resize is hard-coded to 256x256 while the config cell defines
# IMAGE_SIZE = 300; confirm which one matches how the checkpoint was trained.
dataset_aug = albumentations.Compose(
    [
        albumentations.Resize(256, 256),
    ]
)

test_dataset = ImageDataset(
    image_paths=test_image_paths,
    targets=test_targets,
    augmentations=dataset_aug,
    backend="cv2",
)

# --- Model ----------------------------------------------------------------------
model = InstaModel(num_classes=dfx["category"].nunique())
# `es` and `config` are built here but are not passed to load() or predict()
# below; they are kept because other cells may rely on them.
es = EarlyStopping(
    monitor="valid_loss",
    model_path=os.path.join(MODEL_PATH, MODEL_NAME + ".bin"),
    patience=3,
    mode="min",
)

model = Tez(model)
config = TezConfig(
    training_batch_size=TRAIN_BATCH_SIZE,
    validation_batch_size=VALID_BATCH_SIZE,
    epochs=EPOCHS,
    step_scheduler_after="epoch",
    step_scheduler_metric="valid_loss",
)

model.load("../input/using-tez-in-leaf-disease-classification/resnet18.bin")

# --- Predict --------------------------------------------------------------------
# predict() is consumed batch by batch; each batch contributes one row of
# per-class scores per image.
preds = model.predict(test_dataset, batch_size=VALID_BATCH_SIZE, n_jobs=-1)
score_rows = list()
for yhat in preds:
    score_rows.extend(yhat)
test_scores = np.asarray(score_rows)

# Persist the raw per-class scores before collapsing them to class ids.
np.savetxt(
    "./predictions_text.csv",
    test_scores,
    delimiter=", ",
    fmt="% s",
)

# --- Evaluate -------------------------------------------------------------------
predictions = np.argmax(test_scores, axis=1)
# The value is computed with accuracy_score, so it is labelled as accuracy.
print("Accuracy image: " + str(metrics.accuracy_score(dfx["category"].values, predictions)))



Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth


Precision image: 0.27395
