In [None]:
# default_exp core

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
import re

In [None]:
#export
import transformers
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from dataclasses import asdict
from collections import OrderedDict
from typing import Union, Tuple, Sequence, Set
from numpy.random import RandomState
import numpy as np
from datasets import Dataset
from datasets import load_dataset
from pathlib import Path
from sklearn.model_selection import StratifiedShuffleSplit
from torchvision.transforms import (CenterCrop, 
                                    RandomErasing,
                                    RandomAutocontrast,
                                    Compose, 
                                    Normalize, 
                                    RandomHorizontalFlip,
                                    RandomResizedCrop, 
                                    Resize, 
                                    RandomAdjustSharpness,
                                    ToTensor)
import torch
from transformers import AutoFeatureExtractor, TrainingArguments, Trainer
from transformers import AutoModelForImageClassification
from datasets import load_metric
from rich import print
import re
from dataclasses import dataclass
from typing import Dict
import datasets
import pandas as pd

In [None]:
%%bash 
git lfs update --force

Updated Git hooks.


Imports used only by the test cells in this notebook.

In [None]:
from numpy.testing import assert_allclose
from toolz.dicttoolz import valmap
from collections import Counter
from toolz import frequencies

## Data loading

In [None]:
ds = load_dataset("davanstrien/hugit_hmd_flysheet", use_auth_token=True, streaming=False, split='train')

HTTPError: 401 Client Error: Unauthorized for url: https://huggingface.co/api/datasets/davanstrien/hugit_hmd_flysheet?full=true

In [None]:
f =  '/Users/dvanstrien/Documents/DS/hmd_flysheet_detection/data/Flysheet_data/CONTAINER/or_5268_fse002r/Users/dvanstrien/Documents/DS/hmd_flysheet_detection/data/Flysheet_data/CONTAINER/or_5268_fse002r (1).jpg.jpg'

In [None]:
f

In [None]:
f = re.sub(r"(\(\d\))","",f)

In [None]:
f.split('.')[0]

In [None]:
#export
def return_base_path_deduplicated(x):
    """Derive a de-duplication key from an example's file path.

    Copy markers such as ``(1)`` or ``(12)`` are removed from ``x['fpath']``,
    everything from the first ``.`` onwards is dropped, and trailing
    whitespace is stripped.

    Args:
        x: mapping with an ``'fpath'`` string entry.

    Returns:
        A dict ``{"clean_path": <normalised path>}``.
    """
    # `\d+` instead of `\d`: markers with two or more digits ("(12)") must be
    # stripped as well, otherwise those copies are never treated as duplicates.
    path = re.sub(r"\(\d+\)", "", x['fpath'])
    # NOTE: this cuts at the first dot of the whole path, so a dot inside a
    # directory name truncates the key there too.
    stem = path.split(".")[0]
    return {"clean_path": stem.rstrip()}

In [None]:
#export
def check_uniques(example, uniques, column='clean_path'):
    """Report whether `example` is the first one seen for its `column` value.

    `uniques` is mutated: the value is removed on its first sighting, so any
    later example carrying the same value yields False.
    """
    key = example[column]
    if key not in uniques:
        return False
    uniques.remove(key)
    return True

In [None]:
ds

In [None]:
#export
def drop_duplicates(ds):
    """Keep one example per distinct `clean_path`.

    A `clean_path` column is added with `return_base_path_deduplicated`, then
    the dataset is filtered with `check_uniques`, which lets an example through
    only the first time its key is encountered.
    """
    with_keys = ds.map(return_base_path_deduplicated)
    # Shared, mutable set consumed by check_uniques while filtering.
    # NOTE(review): assumes `filter` visits examples sequentially in a single
    # process so the set is mutated consistently -- confirm.
    pending = set(with_keys['clean_path'])
    return with_keys.filter(check_uniques, fn_kwargs={"uniques": pending})

In [None]:
ds = drop_duplicates(ds)

In [None]:
#export
def get_id(example):
    """Build an `id` from the underscore-separated file name of `example['fpath']`.

    With three or more underscore-separated parts the first two are joined;
    shorter names are kept whole.
    """
    parts = Path(example["fpath"]).name.split("_")
    if len(parts) >= 3:
        kept = parts[:2]
    else:
        kept = parts[:3]
    return {"id": "_".join(kept)}

In [None]:
ds = ds.map(get_id)

In [None]:
ds[0]

In [None]:
ds

## Train, valid, test splits

In [None]:
#export
def split_w_stratify(
    ds,
    test_size: Union[int, float],
    train_size: Union[int, float, None] = None,
    random_state: Union[int, RandomState, None] = None,
) -> Tuple[Dataset, Dataset]:
    """Split `ds` in two while stratifying on its `label` column.

    Args:
        ds: dataset exposing a `label` column and a `select(indices)` method.
        test_size: size of the second split, passed to `StratifiedShuffleSplit`.
        train_size: size of the first split, passed to `StratifiedShuffleSplit`
            (previously accepted but ignored).
        random_state: seed or `RandomState` passed to `StratifiedShuffleSplit`.

    Returns:
        A `(first, second)` tuple of datasets selected from `ds`.
    """
    label_array = np.array(ds['label'])
    # Only one split is ever consumed, so a single split is requested.
    splitter = StratifiedShuffleSplit(
        n_splits=1,
        test_size=test_size,
        train_size=train_size,
        random_state=random_state,
    )
    # The features play no part in stratification; a zero placeholder of the
    # right length stands in for them.
    placeholder = np.zeros(len(label_array))
    train_inds, valid_inds = next(splitter.split(placeholder, y=label_array))
    return ds.select(train_inds), ds.select(valid_inds)

In [None]:
train, valid = split_w_stratify(ds, test_size=0.5)

Check that the stratified split preserves the label frequencies.

In [None]:
# The two halves of a 50/50 split should differ by at most one row; the
# previous rtol=2 accepted practically any pair of sizes.
assert_allclose(train.shape, valid.shape, atol=1)

In [None]:
train_freqs = frequencies(train['label'])
train_freqs

In [None]:
train_percentages =  OrderedDict(sorted(valmap(lambda x: x/len(train_freqs),train_freqs).items())).values()
train_percentages

In [None]:
valid_freqs = frequencies(valid['label'])
valid_percentages = OrderedDict(sorted(valmap(lambda x: x/len(valid_freqs),valid_freqs).items())).values()
valid_percentages

In [None]:
assert_allclose(list(train_percentages), list(valid_percentages), atol=1)

In [None]:
#export
def train_valid_split_w_stratify(
    ds,
    valid_size: Union[int, float, None] = None,
    test_size: Union[int, float] = 0.3,
    train_size: Union[int, float, None] = None,
    random_state: Union[int, RandomState, None] = None,
) -> Tuple[Dataset, Dataset, Dataset]:
    """Split `ds` into stratified train, validation and test datasets.

    The data is first split into train and a held-out remainder using
    `test_size`; the remainder is then split again with the same `test_size`
    into validation and test.

    Args:
        ds: dataset accepted by `split_w_stratify`.
        valid_size: currently unused.
        test_size: size argument used for both splits.
        train_size: currently unused.
        random_state: seed or `RandomState` forwarded to both splits so a
            seeded call is reproducible (previously accepted but ignored).

    Returns:
        A `(train, valid, test)` tuple.
    """
    train, valid_test = split_w_stratify(
        ds, test_size=test_size, random_state=random_state
    )
    valid, test = split_w_stratify(
        valid_test, test_size=test_size, random_state=random_state
    )
    return train, valid, test

In [None]:
train, valid, test = train_valid_split_w_stratify(ds)

In [None]:
#export
def prepare_dataset(ds):
    """Deduplicate `ds`, attach ids and return stratified (train, valid, test) splits.

    Progress messages and the size of each split are printed along the way.
    """
    print("Preparing dataset...")
    print("dropping duplicates...")
    deduped = drop_duplicates(ds)
    print("getting ID...")
    with_ids = deduped.map(get_id)
    print("creating train, valid, test splits...")
    train, valid, test = train_valid_split_w_stratify(with_ids)
    for split_name, split in (("train", train), ("valid", valid), ("test", test)):
        print(f"{split_name} has {len(split)} examples")
    return train, valid, test

In [None]:
ds = load_dataset("davanstrien/flysheet", use_auth_token=True, streaming=False, split='train')

In [None]:
train,valid,test = prepare_dataset(ds)
train,valid,test

## Augmentations 

In [None]:
model_checkpoint = "davanstrien/vit-base-patch16-224-in21k-base-manuscripts"

In [None]:
#export
def prepare_transforms(model_checkpoint, train_ds, valid_ds, test_ds=None):
    """Attach on-the-fly image transforms to the dataset splits.

    The feature extractor loaded from `model_checkpoint` supplies the target
    size and the normalisation statistics. The training split gets an
    augmenting pipeline (resize, sharpness, autocontrast, tensor conversion,
    normalisation, random erasing); the validation split and, when given, the
    test split get the plain pipeline (resize, tensor conversion,
    normalisation).

    Returns:
        `(train_ds, valid_ds, test_ds)` -- the objects passed in, after
        `set_transform` has been called on each one that is not None.
    """
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
    normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
    side = feature_extractor.size

    augmenting_pipeline = Compose([
        Resize((side, side)),
        RandomAdjustSharpness(0.1),
        RandomAutocontrast(),
        ToTensor(),
        normalize,
        RandomErasing(),
    ])
    plain_pipeline = Compose([
        Resize((side, side)),
        ToTensor(),
        normalize,
    ])

    def _batch_transform(pipeline):
        # Wrap a per-image pipeline into a function over a batch of examples.
        def apply(examples):
            examples['pixel_values'] = [pipeline(image.convert("RGB")) for image in examples['image']]
            return examples
        return apply

    train_transforms = _batch_transform(augmenting_pipeline)
    val_transforms = _batch_transform(plain_pipeline)

    if test_ds is not None:
        test_ds.set_transform(val_transforms)
    train_ds.set_transform(train_transforms)
    valid_ds.set_transform(val_transforms)
    return train_ds, valid_ds, test_ds

In [None]:
train_ds, valid_ds, test_ds = prepare_transforms(model_checkpoint, train,valid, test)

In [None]:
train_ds[0]['pixel_values'].shape

In [None]:
#export
@dataclass
class FlyswotData:
    """Bundle of the three dataset splits together with the label mappings."""
    # NOTE: field order is part of the contract -- this notebook unpacks
    # `asdict(data).values()` positionally, so do not reorder the fields.
    train_ds: datasets.arrow_dataset.Dataset  # training split
    valid_ds: datasets.arrow_dataset.Dataset  # validation split
    test_ds: datasets.arrow_dataset.Dataset  # test split
    id2label: Dict[int,str]  # class index -> label name
    label2id: Dict[str,int]  # label name -> class index

In [None]:
#export
def prep_data(ds_checkpoint="davanstrien/flysheet", model_checkpoint=None):
    """Load a dataset and return it as splits that are ready for training.

    The `train` split of `ds_checkpoint` is loaded, the label mappings are
    built from its `label` feature, the data is passed through
    `prepare_dataset`, and the transforms for `model_checkpoint` are attached
    with `prepare_transforms`.

    Returns:
        A `FlyswotData` instance. If a `FileNotFoundError` is raised at any
        step, a hint about logging into the Hugging Face Hub is printed and
        `None` is returned instead.
    """
    try:
        raw = load_dataset(ds_checkpoint, use_auth_token=True, streaming=False, split='train')
        label_names = raw.info.features['label'].names
        id2label = {idx: name for idx, name in enumerate(label_names)}
        label2id = {name: idx for idx, name in id2label.items()}
        train, valid, test = prepare_dataset(raw)
        train_ds, valid_ds, test_ds = prepare_transforms(model_checkpoint, train, valid, test)
        return FlyswotData(
            train_ds=train_ds,
            valid_ds=valid_ds,
            test_ds=test_ds,
            id2label=id2label,
            label2id=label2id,
        )
    except FileNotFoundError as e:
        print(f"{e} make sure you are logged into the Hugging Face Hub")
        return None

In [None]:
data = prep_data(model_checkpoint=model_checkpoint)
data

In [None]:
data

In [None]:
from dataclasses import asdict

In [None]:
train_ds, valid_ds, test_ds, id2label, label2id = asdict(data).values()

In [None]:
train_ds

## Model training 

In [None]:
#export
def collate_fn(examples):
    """Collate a list of examples into a batch dict.

    The `pixel_values` tensors are stacked along a new first dimension and the
    `label` entries are gathered into a tensor stored under `labels`.
    """
    images = []
    for item in examples:
        images.append(item["pixel_values"])
    batch = {"pixel_values": torch.stack(images)}

    targets = []
    for item in examples:
        targets.append(item["label"])
    batch["labels"] = torch.tensor(targets)
    return batch

In [None]:
from sklearn.metrics import classification_report

In [None]:
#export
def train_model(data, 
                model_checkpoint,
                num_epochs=50,
                hub_model_id="flyswot",
                tune=False,
               fp16=True):
    """Fine-tune an image-classification checkpoint on the prepared splits.

    Args:
        data: object exposing `train_ds`, `valid_ds`, `id2label` and
            `label2id` attributes (e.g. a `FlyswotData` instance).
        model_checkpoint: checkpoint passed to
            `AutoModelForImageClassification.from_pretrained` and
            `AutoFeatureExtractor.from_pretrained`.
        num_epochs: value given to `num_train_epochs`.
        hub_model_id: repository name; the model is pushed as
            `flyswot/<hub_model_id>` because `push_to_hub=True`.
        tune: currently unused.
        fp16: value given to the `fp16` training argument.

    Returns:
        The `Trainer` after `trainer.train()` has finished. Macro-averaged F1
        is computed on `valid_ds` at the end of every epoch.
    """
    transformers.logging.set_verbosity_warning()
    # Read the fields directly: `asdict(data)` deep-copies every field
    # (including the datasets) and positional unpacking of its values breaks
    # silently if the dataclass fields are ever reordered.
    train_ds, valid_ds = data.train_ds, data.valid_ds
    id2label, label2id = data.id2label, data.label2id
    print(train_ds)
    # ignore_mismatched_sizes lets the classification head be re-initialised
    # when the checkpoint was trained with a different number of labels.
    model = AutoModelForImageClassification.from_pretrained(
        model_checkpoint,
        num_labels=len(id2label),
        id2label=id2label,
        label2id=label2id,
        ignore_mismatched_sizes=True,
    )
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
    args = TrainingArguments(
        "output_dir",
        save_strategy="epoch",
        evaluation_strategy="epoch",
        hub_model_id=f"flyswot/{hub_model_id}",
        overwrite_output_dir=True,
        push_to_hub=True,
        learning_rate=2e-5,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        num_train_epochs=num_epochs,
        weight_decay=0.1,
        disable_tqdm=False,
        fp16=fp16,
        # load_best_model_at_end=True,
        # metric_for_best_model="f1",
        logging_dir='logs',
        # keep the `image` column: the datasets' transforms need it at access time
        remove_unused_columns=False,
        save_total_limit=10,
        optim="adamw_torch",
        seed=42,
    )
    f1 = load_metric("f1")

    def compute_metrics(eval_pred):
        # eval_pred unpacks to (logits, reference label ids)
        predictions, labels = eval_pred
        predictions = np.argmax(predictions, axis=1)
        return f1.compute(predictions=predictions, references=labels, average='macro')

    trainer = Trainer(
        model,
        args,
        train_dataset=train_ds,
        eval_dataset=valid_ds,
        data_collator=collate_fn,
        compute_metrics=compute_metrics,
        tokenizer=feature_extractor,
    )
    trainer.train()
    return trainer

In [None]:
trainer = train_model(data, "facebook/deit-tiny-patch16-224",0.001, fp16=False, hub_model_id='test')

## Model management

In [None]:
data

## Model Evaluation 

In [None]:
outputs = trainer.predict(data.test_ds)

In [None]:
#export
def plot_confusion_matrix(outputs, trainer):
    """Plot a confusion matrix for a set of predictions.

    Args:
        outputs: object with `label_ids` (true class ids) and `predictions`
            (per-class scores, arg-maxed along axis 1).
        trainer: object whose `model.config.id2label` maps class id to name.
    """
    from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
    import matplotlib.pyplot as plt
    id2label = trainer.model.config.id2label
    # NOTE(review): assumes the id2label keys are the integer class ids found
    # in `label_ids` -- confirm for configs built elsewhere.
    class_ids = sorted(id2label)
    class_names = [id2label[class_id] for class_id in class_ids]
    y_true = outputs.label_ids
    y_pred = outputs.predictions.argmax(1)
    # Pin the matrix to the full label set: without `labels=` a class absent
    # from both y_true and y_pred shrinks the matrix, which then no longer
    # lines up with the display labels.
    cm = confusion_matrix(y_true, y_pred, labels=class_ids)
    fig, ax = plt.subplots(figsize=(15, 15))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
    disp.plot(xticks_rotation=45, ax=ax)


In [None]:
plot_confusion_matrix(outputs,trainer)

In [None]:
#export
def create_classification_report(outputs, trainer):
    """Return scikit-learn's classification report as a dict.

    Args:
        outputs: object with `label_ids` (true class ids) and `predictions`
            (per-class scores, arg-maxed along axis 1).
        trainer: object whose `model.config.id2label` maps class id to name.

    Returns:
        The dict produced by `classification_report(..., output_dict=True)`.
    """
    from sklearn.metrics import classification_report
    id2label = trainer.model.config.id2label
    # NOTE(review): assumes the id2label keys are the integer class ids found
    # in `label_ids` -- confirm for configs built elsewhere.
    class_ids = sorted(id2label)
    class_names = [id2label[class_id] for class_id in class_ids]
    y_true = outputs.label_ids
    y_pred = outputs.predictions.argmax(1)
    # Passing `labels=` keeps `target_names` aligned even when some class is
    # absent from both y_true and y_pred; without it the lengths can differ.
    return classification_report(
        y_true,
        y_pred,
        labels=class_ids,
        target_names=class_names,
        output_dict=True,
    )

In [None]:
results = create_classification_report(outputs, trainer,)

In [None]:
metrics = trainer.evaluate()
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

In [None]:
kwargs = {
        "tasks": "image-classification",
        "tags": ["image-classification", "vision"],
    }
   
trainer.push_to_hub(**kwargs)

In [None]:
# Misclassified report

In [None]:
#export
def create_misclassified_report(outputs,trainer, important_label=None, print_results=True):
    """Collect the examples whose predicted label differs from the true label.

    Args:
        outputs: object with `label_ids` (true class ids) and `predictions`
            (per-class scores, arg-maxed along axis 1).
        trainer: object whose `model.config.id2label` maps class id to name.
        important_label: optional label name; when given and `print_results`
            is true, the number and percentage (of all examples) of wrong
            predictions of that label are printed.
        print_results: whether to print the misclassified rows and the
            optional `important_label` summary.

    Returns:
        A DataFrame with `y_true` and `y_pred` label-name columns holding only
        the misclassified rows. It is returned whether or not printing is
        enabled.
    """
    id2label = trainer.model.config.id2label
    df = pd.DataFrame({
        "y_true": outputs.label_ids,
        "y_pred": outputs.predictions.argmax(1),
    })
    df["y_true"] = df["y_true"].map(id2label)
    df["y_pred"] = df["y_pred"].map(id2label)
    # Built outside the printing branch so the return value never depends on
    # `print_results`.
    misclassified_df = df[df["y_true"] != df["y_pred"]]
    if print_results:
        print('misclassified:')
        print(misclassified_df)
        print('\n')
        if important_label:
            wrong = len(misclassified_df[misclassified_df['y_pred'] == important_label])
            # Guard against an empty prediction set.
            percentage = (wrong / len(df)) * 100 if len(df) else 0.0
            print(f"Number of wrong predictions of {important_label} label: {wrong}")
            print(f"Percentage of wrong predictions of {important_label} label: {percentage}")
    return misclassified_df


In [None]:
trainer.push_to_hub()

In [None]:
import pandas as pd

In [None]:
y_true = outputs.label_ids
y_pred = outputs.predictions.argmax(1)
df = pd.DataFrame({"y_true":y_true,"y_pred": y_pred})