In [None]:
!pip install gdown

In [2]:
import os

import gdown

# Google Drive link to the MIDV-2020 archive (~4 GB)
url = 'https://drive.google.com/uc?id=1YQwBlGvqodie2zzurPIqZ7dswTHc6-9o'

output = 'MIDV-2020.zip'
# Skip the multi-gigabyte download when the archive is already on disk so the
# notebook can be re-run cheaply; delete the file to force a fresh download.
if not os.path.exists(output):
    gdown.download(url, output)
output

Downloading...
From: https://drive.google.com/uc?id=1YQwBlGvqodie2zzurPIqZ7dswTHc6-9o
To: /kaggle/working/MIDV-2020.zip
100%|██████████| 3.98G/3.98G [00:15<00:00, 256MB/s] 


'MIDV-2020.zip'

In [3]:
# Extract in pure Python: a corrupt or missing archive raises an exception
# instead of being silently ignored, and re-running the cell overwrites the
# already-extracted files without relying on unzip's interactive prompt.
import zipfile

with zipfile.ZipFile('MIDV-2020.zip') as archive:
    archive.extractall()
print('[INFO] The dataset has been unzipped...')

[INFO] The dataset has been unzipped...


In [4]:
import io
import math
import os
import sys
import torch
import wandb
import albumentations as A
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sn
from albumentations.pytorch import ToTensorV2
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_curve
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision import models
from torchvision import datasets
from tqdm import tqdm

In [4]:
import wandb

# Authenticate with Weights & Biases so the training run further down can log
# its metrics; the returned flag is shown as this cell's output.
login_ok = wandb.login()
login_ok

[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize


[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:  ········································


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [6]:
# Define the name of the dataset
DATASET_NAME = 'MIDV-2020'

# Define the path to train and valid dataset
TRAIN_PATH = os.path.join(DATASET_NAME, 'train')
VALID_PATH = os.path.join(DATASET_NAME, 'test')

# Define the input image dimensions
IMAGE_SIZE = 224

# Define the labels. They are derived exactly like torchvision's ImageFolder
# derives its classes (sorted names of sub-directories only), so LABELS[i] is
# the name of class index i. A plain os.listdir would also pick up stray files
# (e.g. hidden files) and shift every label that sorts after them.
with os.scandir(os.path.join(TRAIN_PATH, 'images')) as _entries:
    LABELS = sorted(entry.name for entry in _entries if entry.is_dir())

In [7]:
class Transforms:
    """Callable image transform for torchvision datasets.

    Resizes the image to IMAGE_SIZE x IMAGE_SIZE, normalizes it with
    albumentations' default Normalize settings and converts it to a tensor.
    """

    def __init__(self):
        steps = [
            A.Resize(IMAGE_SIZE, IMAGE_SIZE),
            A.Normalize(),
            ToTensorV2(),
        ]
        self.transforms = A.Compose(steps)

    def __call__(self, img, *args, **kwargs):
        # albumentations works on numpy arrays, so convert the incoming image first
        as_array = np.array(img)
        transformed = self.transforms(image=as_array)
        return transformed['image']

# Create the train and valid datasets
train_dataset = datasets.ImageFolder(os.path.join(TRAIN_PATH, 'images'), transform=Transforms())
valid_dataset = datasets.ImageFolder(os.path.join(VALID_PATH, 'images'), transform=Transforms())

# ImageFolder numbers classes from each split's own sub-directories, so the two
# splits (and LABELS, used for reports and confusion matrices) must list exactly
# the same classes, otherwise class indices silently disagree between them.
assert train_dataset.classes == valid_dataset.classes, 'train/test class folders differ'
assert train_dataset.classes == LABELS, 'LABELS does not match the ImageFolder classes'

# os.cpu_count() may return None, which DataLoader does not accept
NUM_WORKERS = os.cpu_count() or 0

# Create the train and valid data loaders
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=NUM_WORKERS)
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=NUM_WORKERS)

In [10]:
# Determine the device to be used for training and evaluation
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Initialize the ImageNet-pretrained ResNet-50. `pretrained=True` is deprecated
# since torchvision 0.13; IMAGENET1K_V1 is the checkpoint it used to select
# (resnet50-0676ba61.pth), so the starting weights are unchanged.
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
# Replace the ImageNet classification head with one output per document class
model.fc = nn.Linear(model.fc.in_features, len(LABELS))

# Send the model to the GPU if available
model = model.to(DEVICE)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


  0%|          | 0.00/97.8M [00:00<?, ?B/s]

In [10]:
def format_logs(metrics):
    """Render a metrics mapping as 'name - value' pairs joined by ', '.

    Values are formatted with 4 significant digits ('{:.4}').
    """
    template = "{} - {:.4}"
    return ', '.join(template.format(name, value) for name, value in metrics.items())

In [22]:
def _epoch_metrics(y_true, y_pred, mean_loss):
    """Summarise the predictions collected so far into scalar metrics.

    Args:
        y_true: 1-D sequence of true class indices.
        y_pred: 1-D sequence of predicted class indices, aligned with y_true.
        mean_loss: per-sample mean loss over the same samples (float).

    Returns:
        dict with 'loss', 'accuracy' and macro-averaged 'precision', 'recall'
        and 'f1_score' (zero_division=1).
    """
    return {
        'loss': mean_loss,
        'accuracy': accuracy_score(y_true, y_pred),
        'precision': precision_score(y_true, y_pred, average='macro', zero_division=1),
        'recall': recall_score(y_true, y_pred, average='macro', zero_division=1),
        'f1_score': f1_score(y_true, y_pred, average='macro', zero_division=1),
    }


def run_epoch(epoch, data_loader, mode='train'):
    """Run one pass over `data_loader`, training or evaluating the global model.

    Uses the module-level `model`, `criterion`, `optimizer`, `DEVICE` and
    `LABELS`. In 'train' mode the weights are updated on every batch; any
    other mode evaluates with gradients disabled.

    Args:
        epoch: epoch number, only used in progress messages.
        data_loader: sized iterable of (inputs, labels) batches.
        mode: 'train' to optimise the model; anything else evaluates it.

    Returns:
        dict with
          'matrix': confusion matrix, rows = true class, columns = predicted
                    class, always len(LABELS) x len(LABELS);
          'report' / 'report_dict': sklearn classification report (text / dict);
          'metrics': epoch metrics, where 'loss' is the mean loss over all
                     samples of the epoch (a float).

    Raises:
        ValueError: if `data_loader` yields no batches.
    """
    is_train = mode == 'train'
    model.train(is_train)  # train(False) is equivalent to eval()

    true_batches, pred_batches = [], []
    loss_sum, n_samples = 0.0, 0
    n_batches = len(data_loader)

    for batch_idx, (inputs, labels) in enumerate(data_loader, 1):
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        if is_train:
            optimizer.zero_grad()

        with torch.set_grad_enabled(is_train):
            outputs = model(inputs)
            loss = criterion(outputs, labels)

        if is_train:
            loss.backward()
            optimizer.step()

        # Assumes criterion averages over the batch (CrossEntropyLoss default),
        # so weight by batch size to get an exact per-sample mean for the epoch
        # (the last batch may be smaller than the others).
        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size
        n_samples += batch_size

        # Keep per-batch tensors and concatenate only when needed; concatenating
        # on every batch made the epoch quadratic in the number of batches.
        pred_batches.append(outputs.argmax(dim=1).view(-1).cpu())
        true_batches.append(labels.view(-1).cpu())

        if batch_idx % 5 == 0:
            metrics = _epoch_metrics(torch.cat(true_batches), torch.cat(pred_batches),
                                     loss_sum / n_samples)
            print('Epoch: {} - {}: {}% ({}/{}) [{}]'.format(epoch, mode,
                                                            int(100 * batch_idx / n_batches),
                                                            batch_idx, n_batches,
                                                            format_logs(metrics)))

    if n_samples == 0:
        raise ValueError('data_loader for mode {!r} yielded no batches'.format(mode))

    y_true = torch.cat(true_batches).numpy()
    y_pred = torch.cat(pred_batches).numpy()
    metrics = _epoch_metrics(y_true, y_pred, loss_sum / n_samples)

    # sklearn's signature is confusion_matrix(y_true, y_pred): rows are the true
    # classes, matching the 'True Label' axis used when the matrix is plotted.
    # Passing `labels` fixes the matrix shape and the report length even when
    # some class is absent from both y_true and y_pred in this split.
    class_ids = list(range(len(LABELS)))
    matrix = confusion_matrix(y_true, y_pred, labels=class_ids)
    report = classification_report(y_true, y_pred, labels=class_ids, target_names=LABELS)
    report_dict = classification_report(y_true, y_pred, labels=class_ids,
                                        target_names=LABELS, output_dict=True)

    return {'matrix': matrix, 'report': report, 'report_dict': report_dict, 'metrics': metrics}

In [12]:
def save_matrix(matrix, file_name):
    """Render a confusion matrix as an annotated heatmap and save it to a file.

    Args:
        matrix: square array of counts ordered like LABELS; rows are labelled
            'True Label' and columns 'Predicted Label'.
        file_name: path of the image file to write.
    """
    df_cm = pd.DataFrame(matrix, index=list(LABELS), columns=list(LABELS)).astype(int)
    fig, ax = plt.subplots(figsize=(15, 10))
    try:
        # fmt='d' prints raw integer counts; seaborn's default fmt ('.2g') would
        # show any count >= 100 in scientific notation (e.g. '1.2e+02').
        sn.heatmap(df_cm, annot=True, fmt='d', ax=ax)
        ax.yaxis.set_ticklabels(ax.yaxis.get_ticklabels(), rotation=0, ha='right')
        ax.xaxis.set_ticklabels(ax.xaxis.get_ticklabels(), rotation=45, ha='right')
        ax.set_ylabel('True Label')
        ax.set_xlabel('Predicted Label')
        fig.savefig(file_name, bbox_inches='tight')
    finally:
        # Always release the figure, even if rendering fails, so repeated calls
        # (one per epoch and split) do not accumulate open figures.
        plt.close(fig)

In [13]:
def get_report(report):
    """Convert a classification_report(..., output_dict=True) dict into a wandb.Table.

    Every per-class / average entry becomes a row of 4-decimal strings plus its
    support. The scalar 'accuracy' entry is placed in the f1_score column, and
    its support is taken from the 'macro avg' entry.
    """
    header = ['', 'precision', 'recall', 'f1_score', 'support']

    def _fmt(value):
        return format(value, '.4f')

    rows = []
    for name, entry in report.items():
        if name != 'accuracy':
            rows.append([name,
                         _fmt(entry['precision']),
                         _fmt(entry['recall']),
                         _fmt(entry['f1-score']),
                         str(entry['support'])])
            continue
        rows.append([name, '', '', _fmt(entry), str(report['macro avg']['support'])])

    return wandb.Table(data=rows, columns=header)

In [None]:
# WandB – Initialize a new run
NUM_EPOCHS = 10  # number of passes over the training set

wandb.init(entity="jsun-dev", project="document-classification")
wandb.watch(model, log="all")

best_epoch = 0
best_loss = float('inf')
best_accuracy = 0.0
best_metrics = {}

# Scalar metrics returned by run_epoch, logged for both splits every epoch
METRIC_NAMES = ('loss', 'accuracy', 'precision', 'recall', 'f1_score')


def _run_and_report(epoch, data_loader, mode, matrix_file):
    """Run one epoch, save its confusion-matrix image and print its report.

    Returns (report_dict, metrics). Every metric is converted to a plain float
    so values that arrive as torch/numpy scalars (e.g. a CUDA loss tensor) can
    be compared, stored in the W&B summary and logged uniformly.
    """
    logs = run_epoch(epoch, data_loader, mode=mode)
    save_matrix(logs['matrix'], matrix_file)
    print('\n{}'.format(logs['report']))
    metrics = {name: float(value) for name, value in logs['metrics'].items()}
    return logs['report_dict'], metrics


try:
    # Train the model
    for i in range(NUM_EPOCHS):
        # Train and optimize on the training dataset
        train_report, train_metrics = _run_and_report(i, train_loader, 'train', 'train_matrix.png')

        # Evaluate on the validation dataset
        valid_report, valid_metrics = _run_and_report(i, valid_loader, 'valid', 'valid_matrix.png')

        # Criteria for improved model: strictly higher accuracy, or accuracy
        # equal within 0.1% (relative) together with a lower loss
        valid_loss = valid_metrics['loss']
        valid_accuracy = valid_metrics['accuracy']
        better_accuracy = valid_accuracy > best_accuracy
        similar_accuracy_better_loss = (math.isclose(valid_accuracy, best_accuracy, rel_tol=1e-3)
                                        and valid_loss < best_loss)

        # Save model if improved
        if better_accuracy or similar_accuracy_better_loss:
            best_epoch = i
            best_loss = valid_loss
            best_accuracy = valid_accuracy
            best_metrics = valid_metrics
            wandb.run.summary['best_epoch'] = best_epoch
            for name in METRIC_NAMES:
                wandb.run.summary['best_{}'.format(name)] = best_metrics[name]
            torch.save(model.state_dict(), './best_model.pth')
            print('Improved model saved!')

        # Log epoch results: '<split>/<metric>' scalars, then reports and matrices
        epoch_log = {}
        for split, metrics in (('train', train_metrics), ('valid', valid_metrics)):
            for name in METRIC_NAMES:
                epoch_log['{}/{}'.format(split, name)] = metrics[name]
        epoch_log.update({
            'report/train': get_report(train_report),
            'report/valid': get_report(valid_report),
            'matrix/train': wandb.Image('train_matrix.png'),
            'matrix/valid': wandb.Image('valid_matrix.png'),
        })
        wandb.log(epoch_log)
finally:
    # Close the W&B run even if an epoch raises or the cell is interrupted,
    # so the run is not left open
    wandb.finish()

<a href="./best_model.pth"> Download Model</a>