In [None]:
# default_exp core

# module name here

> 

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#export
from collections import OrderedDict
from typing import Union, Tuple, Sequence, Set
from numpy.random import RandomState
import numpy as np
from datasets import Dataset
from datasets import load_dataset
from pathlib import Path
from sklearn.model_selection import StratifiedShuffleSplit
from torchvision.transforms import (CenterCrop, 
                                    RandomErasing,
                                    RandomAutocontrast,
                                    Compose, 
                                    Normalize, 
                                    RandomHorizontalFlip,
                                    RandomResizedCrop, 
                                    Resize, 
                                    RandomAdjustSharpness,
                                    ToTensor)
import torch
from transformers import AutoFeatureExtractor, TrainingArguments, Trainer
from transformers import AutoModelForImageClassification
from datasets import load_metric
from rich import print

Imports used only for testing:

In [None]:
from numpy.testing import assert_allclose
from toolz.dicttoolz import valmap
from collections import Counter
from toolz import frequencies

## Data loading

In [None]:
# Load the 'train' split of the flysheet dataset eagerly (streaming disabled),
# authenticating with the locally stored Hugging Face token.
ds = load_dataset("davanstrien/flysheet", use_auth_token=True, streaming=False, split='train')

Using custom data configuration davanstrien--flysheet-2cdc8849e04b41c9
Reusing dataset parquet (/Users/dvanstrien/.cache/huggingface/datasets/parquet/davanstrien--flysheet-2cdc8849e04b41c9/0.0.0/0b6d5799bb726b24ad7fc7be720c170d8e497f575d02d47537de9a5bac074901)


In [None]:
f =  '/Users/dvanstrien/Documents/DS/hmd_flysheet_detection/data/Flysheet_data/CONTAINER/or_5268_fse002r/Users/dvanstrien/Documents/DS/hmd_flysheet_detection/data/Flysheet_data/CONTAINER/or_5268_fse002r (1).jpg.jpg'

In [None]:
f

'/Users/dvanstrien/Documents/DS/hmd_flysheet_detection/data/Flysheet_data/CONTAINER/or_5268_fse002r/Users/dvanstrien/Documents/DS/hmd_flysheet_detection/data/Flysheet_data/CONTAINER/or_5268_fse002r (1).jpg.jpg'

In [None]:
import re

In [None]:
# Scratch check: remove a parenthesised single-digit copy marker such as "(1)" from the path.
f = re.sub(r"(\(\d\))","",f)

In [None]:
# Keep only the text before the first "." (note the trailing space left behind by the removed marker).
f.split('.')[0]

'/Users/dvanstrien/Documents/DS/hmd_flysheet_detection/data/Flysheet_data/CONTAINER/or_5268_fse002r/Users/dvanstrien/Documents/DS/hmd_flysheet_detection/data/Flysheet_data/CONTAINER/or_5268_fse002r '

In [None]:
def return_base_path_deduplicated(x):
    """Build a copy-marker-free base path for a single example.

    Reads ``x['fpath']``, removes parenthesised single-digit markers such as
    ``(1)``, keeps only the text before the first ``.`` and trims trailing
    whitespace.

    Returns:
        dict with a single ``"clean_path"`` key holding the cleaned string.
    """
    copy_marker = r"(\(\d\))"
    stem = re.sub(copy_marker, "", x['fpath']).split(".")[0].rstrip()
    # A second pass catches markers that only appear once an inner one is gone,
    # e.g. "((1)2)" -> "(2)" -> "".
    return {"clean_path": re.sub(copy_marker, "", stem)}

In [None]:
def check_uniques(example, uniques, column='clean_path'):
    """Report whether this is the first time ``example[column]`` is seen.

    ``uniques`` is a mutable collection of values not yet encountered. When the
    example's value is still in it, the value is removed (so later examples with
    the same value are rejected) and ``True`` is returned; otherwise ``False``.
    """
    key = example[column]
    if key not in uniques:
        return False
    uniques.remove(key)
    return True

In [None]:
ds

Dataset({
    features: ['image', 'label', 'fpath'],
    num_rows: 2061
})

In [None]:
def drop_duplicates(ds):
    """Keep one row per distinct ``clean_path``.

    A ``clean_path`` column is added via ``return_base_path_deduplicated``; the
    rows are then filtered with ``check_uniques``, which removes each path from
    the shared set the first time it is seen so later rows with the same path
    are dropped.
    """
    with_clean_paths = ds.map(return_base_path_deduplicated)
    # This set is mutated by check_uniques while the filter runs.
    not_yet_seen = set(with_clean_paths['clean_path'])
    return with_clean_paths.filter(check_uniques, fn_kwargs={"uniques": not_yet_seen})

In [None]:
# Replace `ds` with its de-duplicated version (adds the 'clean_path' column).
ds = drop_duplicates(ds)

0ex [00:00, ?ex/s]

  0%|          | 0/3 [00:00<?, ?ba/s]

In [None]:
def get_id(example):
    """Derive an ``id`` from the file name in ``example["fpath"]``.

    The file name (last path component) is split on ``_`` and the first two
    pieces are re-joined with ``_``. When the name contains fewer than two
    underscores the whole name, extension included, is returned.

    The previous expression ``x[:2] if len(x) >= 3 else x[:3]`` selected exactly
    these elements in both branches, so it is written as a plain slice here.

    NOTE(review): names such as ``add_ms_10455_fse005r.jpg`` therefore map to
    ``add_ms``; if a three-part id (``add_ms_10455``) was intended, the
    selection rule needs to change - confirm against the shelfmark scheme.

    Returns:
        dict with a single ``"id"`` key.
    """
    name_parts = Path(example["fpath"]).name.split("_")
    return {"id": "_".join(name_parts[:2])}

In [None]:
# Add an 'id' column computed from each row's file name.
ds = ds.map(get_id)

0ex [00:00, ?ex/s]

In [None]:
ds[0]

{'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=448x248 at 0x19DE98940>,
 'label': 0,
 'fpath': '/Users/dvanstrien/Documents/DS/hmd_flysheet_detection/data/Flysheet_data/CONTAINER/add_ms_10455_fse005r.jpg',
 'clean_path': '/Users/dvanstrien/Documents/DS/hmd_flysheet_detection/data/Flysheet_data/CONTAINER/add_ms_10455_fse005r',
 'id': 'add_ms'}

In [None]:
ds

Dataset({
    features: ['image', 'label', 'fpath', 'clean_path', 'id'],
    num_rows: 1223
})

## Train, valid, test splits

In [None]:
#export
def split_w_stratify(
    ds,
    test_size: Union[int, float],
    train_size: Union[int, float, None] = None,
    random_state: Union[int, RandomState, None] = None,
) -> Tuple[Dataset, Dataset]:
    """Split ``ds`` in two while preserving the distribution of ``ds['label']``.

    Args:
        ds: dataset exposing a ``'label'`` column and a ``select(indices)`` method.
        test_size: size of the second split, forwarded to ``StratifiedShuffleSplit``.
        train_size: size of the first split, forwarded to ``StratifiedShuffleSplit``
            (previously accepted but ignored). ``None`` keeps the complement of
            ``test_size``.
        random_state: seed or ``RandomState`` for a reproducible shuffle.

    Returns:
        ``(first_split, second_split)`` as selections of ``ds``.
    """
    label_array = np.array(ds['label'])
    splitter = StratifiedShuffleSplit(
        # Only the first split is ever consumed, so one is enough.
        n_splits=1,
        test_size=test_size,
        train_size=train_size,
        random_state=random_state,
    )
    # Only the labels drive the stratification; the feature matrix is a placeholder.
    placeholder_features = np.zeros(len(label_array))
    train_inds, valid_inds = next(splitter.split(placeholder_features, y=label_array))
    return ds.select(train_inds), ds.select(valid_inds)

In [None]:
# Stratified 50/50 split used by the checks below.
train, valid = split_w_stratify(ds, test_size=0.5)

Check that label frequencies are similar across the train and validation splits.

In [None]:
# rtol=2 tolerates a difference of up to twice the valid shape, so this check is very loose.
# NOTE(review): tighten rtol if the two halves are expected to be nearly equal in size.
assert_allclose(train.shape, valid.shape,rtol=2)

In [None]:
train_freqs = frequencies(train['label'])
train_freqs

{5: 25, 2: 108, 3: 93, 0: 36, 6: 139, 1: 30, 4: 164, 7: 16}

In [None]:
# Per-label values ordered by label id. Each count is divided by len(train_freqs),
# i.e. the number of distinct labels, not the number of examples, so these are
# scaled counts rather than true percentages.
# NOTE(review): divide by len(train) if real proportions are intended.
train_percentages =  OrderedDict(sorted(valmap(lambda x: x/len(train_freqs),train_freqs).items())).values()
train_percentages

odict_values([4.5, 3.75, 13.5, 11.625, 20.5, 3.125, 17.375, 2.0])

In [None]:
# Same computation for the validation split: label counts, each divided by the
# number of distinct labels (not by the number of examples), ordered by label id.
# NOTE(review): divide by len(valid) if real proportions are intended.
valid_freqs = frequencies(valid['label'])
valid_percentages = OrderedDict(sorted(valmap(lambda x: x/len(valid_freqs),valid_freqs).items())).values()
valid_percentages

odict_values([4.5, 3.625, 13.625, 11.75, 20.625, 3.125, 17.375, 1.875])

In [None]:
# The per-label values of both splits must agree within an absolute tolerance of 1.
assert_allclose(list(train_percentages), list(valid_percentages), atol=1)

In [None]:
#export
def train_valid_split_w_stratify(
    ds,
    valid_size: Union[int,float]=None,
    test_size: Union[int, float]=0.3,
    train_size: Union[int, float, None] = None,
    random_state: Union[int, RandomState, None] = None,
) -> Tuple[Dataset,Dataset, Dataset]:
    """Produce stratified train, validation and test splits of ``ds``.

    ``ds`` is first split into train and a held-out part of size ``test_size``;
    the held-out part is then split again with the same ``test_size`` to give
    validation and test. With the default ``0.3`` this yields roughly
    70% / 21% / 9% of the rows.

    Args:
        ds: dataset with a ``'label'`` column.
        valid_size: currently unused.
            NOTE(review): decide how this should size the validation split.
        test_size: size passed to both stratified splits.
        train_size: forwarded to the first split (previously ignored).
        random_state: forwarded to both splits so results are reproducible
            (previously ignored).

    Returns:
        ``(train, valid, test)`` datasets.
    """
    train, held_out = split_w_stratify(
        ds, test_size=test_size, train_size=train_size, random_state=random_state
    )
    valid, test = split_w_stratify(
        held_out, test_size=test_size, random_state=random_state
    )
    return train, valid, test

In [None]:
# Three-way stratified split with the default sizes; rebinds `train` and `valid` from the 50/50 check above.
train, valid, test = train_valid_split_w_stratify(ds)

In [None]:
def prepare_dataset(ds):
    """Run the full preparation pipeline and return the splits by name.

    Steps, each announced with a progress message: de-duplicate rows, add the
    ``id`` column, then create stratified train/valid/test splits.

    Returns:
        dict with keys ``"train"``, ``"valid"`` and ``"test"``.
    """
    print("Preparing dataset...")
    print("dropping duplicates...")
    deduplicated = drop_duplicates(ds)
    print("getting ID...")
    with_ids = deduplicated.map(get_id)
    print("creating train, valid, test splits...")
    splits = train_valid_split_w_stratify(with_ids)
    data = dict(zip(("train", "valid", "test"), splits))
    for split_name, split in data.items():
        print(f"{split_name} has {len(split)} examples")
    return data

In [None]:
ds = load_dataset("davanstrien/flysheet", use_auth_token=True, streaming=False, split='train')

Using custom data configuration davanstrien--flysheet-2cdc8849e04b41c9
Reusing dataset parquet (/Users/dvanstrien/.cache/huggingface/datasets/parquet/davanstrien--flysheet-2cdc8849e04b41c9/0.0.0/0b6d5799bb726b24ad7fc7be720c170d8e497f575d02d47537de9a5bac074901)


In [None]:
# Run the whole preparation pipeline on the freshly loaded dataset and show the resulting splits.
data = prepare_dataset(ds)
data

Loading cached processed dataset at /Users/dvanstrien/.cache/huggingface/datasets/parquet/davanstrien--flysheet-2cdc8849e04b41c9/0.0.0/0b6d5799bb726b24ad7fc7be720c170d8e497f575d02d47537de9a5bac074901/cache-1dfba86c98cf68e0.arrow
Loading cached processed dataset at /Users/dvanstrien/.cache/huggingface/datasets/parquet/davanstrien--flysheet-2cdc8849e04b41c9/0.0.0/0b6d5799bb726b24ad7fc7be720c170d8e497f575d02d47537de9a5bac074901/cache-f34fde9b42e1bc61.arrow


Loading cached processed dataset at /Users/dvanstrien/.cache/huggingface/datasets/parquet/davanstrien--flysheet-2cdc8849e04b41c9/0.0.0/0b6d5799bb726b24ad7fc7be720c170d8e497f575d02d47537de9a5bac074901/cache-3c87c3e3b7dda5fa.arrow


{'train': Dataset({
     features: ['image', 'label', 'fpath', 'clean_path', 'id'],
     num_rows: 856
 }),
 'valid': Dataset({
     features: ['image', 'label', 'fpath', 'clean_path', 'id'],
     num_rows: 256
 }),
 'test': Dataset({
     features: ['image', 'label', 'fpath', 'clean_path', 'id'],
     num_rows: 111
 })}

## Augmentations 

In [None]:
model_checkpoint = "davanstrien/vit-base-patch16-224-in21k-base-manuscripts"

In [None]:
#export
def prepare_transforms(model_checkpoint, train_ds, valid_ds, test_ds=None):
    """Attach on-the-fly image transforms to the given datasets.

    The feature extractor loaded from ``model_checkpoint`` supplies the target
    size and the normalisation statistics. The training set gets an augmenting
    pipeline (sharpness, autocontrast, random erasing); validation and, when
    given, test sets get a plain resize + normalise pipeline. Each transform
    writes its result to a ``'pixel_values'`` entry of the batch.

    Returns:
        ``(train_ds, valid_ds, test_ds)`` - the same objects that were passed
        in, with transforms set (``test_ds`` stays ``None`` if not given).
    """
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
    normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
    target_size = (feature_extractor.size, feature_extractor.size)

    augmenting_pipeline = Compose(
        [
            Resize(target_size),
            RandomAdjustSharpness(0.1),
            RandomAutocontrast(),
            ToTensor(),
            normalize,
            RandomErasing(),
        ]
    )
    plain_pipeline = Compose([Resize(target_size), ToTensor(), normalize])

    def _batch_transform(pipeline):
        # Build a batch-level function applying `pipeline` to every image as RGB.
        def apply(examples):
            examples['pixel_values'] = [pipeline(image.convert("RGB")) for image in examples['image']]
            return examples

        return apply

    train_transforms = _batch_transform(augmenting_pipeline)
    val_transforms = _batch_transform(plain_pipeline)

    if test_ds is not None:
        test_ds.set_transform(val_transforms)
    train_ds.set_transform(train_transforms)
    valid_ds.set_transform(val_transforms)
    return train_ds, valid_ds, test_ds

In [None]:
# Attach the transforms to the three splits created earlier.
train_ds, valid_ds, test_ds = prepare_transforms(model_checkpoint, train, valid, test)

In [None]:
train_ds[0]['pixel_values'].shape

torch.Size([3, 224, 224])

In [None]:
#export
def load_data(ds_checkpoint="davanstrien/flysheet", model_checkpoint=None):
    """Load a dataset, split it and attach model-specific transforms.

    Args:
        ds_checkpoint: dataset identifier passed to ``load_dataset``.
        model_checkpoint: checkpoint passed on to ``prepare_transforms``.

    Returns:
        ``(train_ds, valid_ds, test_ds, id2label, label2id)`` where the two
        mappings are built from the names of the dataset's ``'label'`` feature.
    """
    ds = load_dataset(ds_checkpoint, use_auth_token=True, streaming=False, split='train')
    id2label = {idx: name for idx, name in enumerate(ds.info.features['label'].names)}
    label2id = {name: idx for idx, name in id2label.items()}
    splits = train_valid_split_w_stratify(ds)
    train_ds, valid_ds, test_ds = prepare_transforms(model_checkpoint, *splits)
    return train_ds, valid_ds, test_ds, id2label, label2id

## Model training 

In [None]:
#export
def collate_fn(examples):
    """Collate a list of examples into one batch.

    Each example must provide a ``"pixel_values"`` tensor and a ``"label"``.

    Returns:
        dict with ``"pixel_values"`` (the tensors stacked along a new first
        dimension) and ``"labels"`` (a tensor built from the labels).
    """
    images, targets = [], []
    for example in examples:
        images.append(example["pixel_values"])
        targets.append(example["label"])
    return {"pixel_values": torch.stack(images), "labels": torch.tensor(targets)}

In [None]:
from torch.optim import AdamW
import transformers



In [None]:
def train_model(ds_checkpoint, model_checkpoint, num_epochs, save_dir,tune=False):
    """Fine-tune an image-classification checkpoint and return the trainer.

    Args:
        ds_checkpoint: dataset identifier handed to ``load_data``.
        model_checkpoint: model / feature-extractor checkpoint to fine-tune.
        num_epochs: value for ``num_train_epochs`` (fractions are passed through).
        save_dir: root directory for checkpoints; outputs are written to
            ``{save_dir}/{model_checkpoint}_flyswot``. Previously the literal
            string ``"save_dir"`` was used and this argument had no effect.
        tune: when truthy, progress bars are disabled.

    Returns:
        The ``Trainer`` after ``train()`` has run, so the fitted model and
        training state remain reachable (previously nothing was returned).
    """
    transformers.logging.set_verbosity_warning()
    train_ds, valid_ds, _test_ds, id2label, label2id = load_data(
        ds_checkpoint, model_checkpoint=model_checkpoint
    )
    model = AutoModelForImageClassification.from_pretrained(
        model_checkpoint,
        num_labels=len(id2label),
        id2label=id2label,
        label2id=label2id,
        # Lets the classifier head be re-created when the checkpoint's label
        # count differs from this dataset's.
        ignore_mismatched_sizes=True,
    )
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
    args = TrainingArguments(
        f"{save_dir}/{model_checkpoint}_flyswot",
        save_strategy="epoch",
        evaluation_strategy="epoch",
        push_to_hub=False,
        learning_rate=2e-5,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        num_train_epochs=num_epochs,
        weight_decay=0.1,
        disable_tqdm=bool(tune),
        fp16=False,
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        logging_dir='logs',
        # Keep the raw 'image' column: the dataset transforms need it.
        remove_unused_columns=False,
        save_total_limit=10,
        # `optim` takes an optimizer name, not a class; this selects torch's AdamW.
        optim="adamw_torch",
        seed=666,
    )
    f1 = load_metric("f1")

    def compute_metrics(eval_pred):
        # Macro-averaged F1 over the arg-max class predictions.
        predictions, labels = eval_pred
        predictions = np.argmax(predictions, axis=1)
        return f1.compute(predictions=predictions, references=labels, average='macro')

    trainer = Trainer(
        model,
        args,
        train_dataset=train_ds,
        eval_dataset=valid_ds,
        data_collator=collate_fn,
        compute_metrics=compute_metrics,
        tokenizer=feature_extractor,
    )
    trainer.train()
    return trainer

In [None]:
# Quick smoke run: fine-tune the DeiT-tiny checkpoint on the flysheet dataset for 0.1 epochs.
train_model('davanstrien/flysheet', "facebook/deit-tiny-patch16-224",0.1,'test',)

ERROR! Session/line number was not unique in database. History logging moved to new session 231


Using custom data configuration davanstrien--flysheet-2cdc8849e04b41c9
Reusing dataset parquet (/Users/dvanstrien/.cache/huggingface/datasets/parquet/davanstrien--flysheet-2cdc8849e04b41c9/0.0.0/0b6d5799bb726b24ad7fc7be720c170d8e497f575d02d47537de9a5bac074901)


Downloading:   0%|          | 0.00/160 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/68.0k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/21.9M [00:00<?, ?B/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at facebook/deit-tiny-patch16-224 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([1000, 192]) in the checkpoint and torch.Size([8, 192]) in the model instantiated
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([8]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
from transformers import PretrainedConfig
from typing import List, Dict


class ResnetConfig(PretrainedConfig):
    """Configuration carrying the input channel count and the head sizes.

    Args:
        num_channels: number of input channels.
        heads: either a single int or a list of ints (the notebook passes
            ``[3, 2]``). The annotations previously disagreed: ``int`` on the
            parameter, ``List[int]`` on the attribute.
        **kwargs: forwarded to ``PretrainedConfig``.
    """

    # NOTE(review): the class is named for ResNet but registers itself as
    # "convnext" - confirm which model type is intended.
    model_type = "convnext"

    def __init__(
        self,
        num_channels=3,
        heads: Union[int, List[int]] = 8,
        **kwargs,
    ):
        self.num_channels = num_channels
        self.heads: Union[int, List[int]] = heads
        super().__init__(**kwargs)

In [None]:
# Build a config with two heads, given as a list of ints.
resnet50d_config = ResnetConfig(heads=[3,2])

In [None]:
resnet50d_config

ResnetConfig {
  "heads": [
    3,
    2
  ],
  "model_type": "convnext",
  "num_channels": 3,
  "transformers_version": "4.16.2"
}

In [None]:
from transformers import PreTrainedModel
from timm.models.resnet import ResNet


class ResnetModel(PreTrainedModel):
    """Wrapper exposing a timm ``ResNet``'s feature extractor as a ``PreTrainedModel``.

    NOTE(review): this class cannot be instantiated as written - see the
    comments in ``__init__``.
    """
    config_class = ResnetConfig

    def __init__(self, config):
        super().__init__(config)
        # NOTE(review): `BLOCK_MAPPING` is not defined in the visible cells and
        # `ResnetConfig` does not set `block_type`; `block_layer` is also never used.
        block_layer = BLOCK_MAPPING[config.block_type]
        # NOTE(review): `ResNet` is constructed without arguments here; it presumably
        # needs at least a block type and layer counts - confirm against timm.
        self.model = ResNet(
        )

    def forward(self, tensor):
        # Returns the wrapped model's features for `tensor` (no classification head applied here).
        return self.model.forward_features(tensor)

In [None]:
from transformers import ConvNextModel

In [None]:
from transformers import AutoConfig

In [None]:
config = AutoConfig.from_pretrained("facebook/convnext-base-384-22k-1k")

Downloading:   0%|          | 0.00/68.0k [00:00<?, ?B/s]

In [None]:
model = ConvNextModel.from_pretrained("facebook/convnext-base-384-22k-1k")

Some weights of the model checkpoint at facebook/convnext-base-384-22k-1k were not used when initializing ConvNextModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing ConvNextModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ConvNextModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
import torch

In [None]:
def create_head(nf, n_out, lin_ftrs=None, ps=0.5, concat_pool=True, first_bn=True, bn_final=False,
                lin_first=False, y_range=None):
    "Model head that takes `nf` features, runs through `lin_ftrs`, and out `n_out` classes."
    # NOTE(review): `L`, `AdaptiveConcatPool2d`, `Flatten`, `LinBnDrop` and `SigmoidRange`
    # are not imported in the visible cells (and `nn` is only imported further down);
    # they look like fastai/fastcore names - confirm and import before use.
    # The input width is doubled when concat pooling is used
    # (presumably because that pool concatenates two pooled outputs - TODO confirm).
    if concat_pool: nf *= 2
    # Layer widths: default is a single hidden layer of 512 between `nf` and `n_out`.
    lin_ftrs = [nf, 512, n_out] if lin_ftrs is None else [nf] + lin_ftrs + [n_out]
    # Batch-norm flags: `first_bn` for the first block, True for the remaining ones.
    bns = [first_bn] + [True]*len(lin_ftrs[1:])
    ps = L(ps)
    # A single dropout value is expanded: half of it for every block but the last.
    if len(ps) == 1: ps = [ps[0]/2] * (len(lin_ftrs)-2) + ps
    # ReLU after every block except the final one, which has no activation.
    actns = [nn.ReLU(inplace=True)] * (len(lin_ftrs)-2) + [None]
    pool = AdaptiveConcatPool2d() if concat_pool else nn.AdaptiveAvgPool2d(1)
    layers = [pool, Flatten()]
    # With `lin_first`, the first dropout value is consumed up front.
    if lin_first: layers.append(nn.Dropout(ps.pop(0)))
    for ni,no,bn,p,actn in zip(lin_ftrs[:-1], lin_ftrs[1:], bns, ps, actns):
        layers += LinBnDrop(ni, no, bn=bn, p=p, act=actn, lin_first=lin_first)
    # With `lin_first`, a final linear layer to `n_out` is appended explicitly.
    if lin_first: layers.append(nn.Linear(lin_ftrs[-2], n_out))
    if bn_final: layers.append(nn.BatchNorm1d(lin_ftrs[-1], momentum=0.01))
    if y_range is not None: layers.append(SigmoidRange(*y_range))
    return nn.Sequential(*layers)

In [None]:
class MultiModel(Module):
    "A shared `body` followed by two heads whose output sizes are `n[0]` and `n[1]`"

    def __init__(self, body: nn.Sequential, n: L):
        # The feature count is measured on a fresh Sequential built from the body's children.
        n_features = num_features_model(nn.Sequential(*body.children()))  #* (2)
        self.body = body
        self.compressed_labels, self.label = (
            create_head(n_features, n[0]),
            create_head(n_features, n[1]),
        )

    def forward(self, x):
        # Both heads consume the same body output; results come back as a two-item list.
        features = self.body(x)
        return [self.compressed_labels(features), self.label(features)]

In [None]:
from transformers import ConvNextPreTrainedModel

In [None]:
from torch import nn

In [None]:
class ConvNextForImageClassification(ConvNextPreTrainedModel):
    """ConvNext backbone with two linear classification heads (work in progress).

    ``__init__`` builds ``classifier1`` and ``classifier2`` from
    ``config.num_labels_1`` and ``config.num_labels_2``. ``forward`` still
    follows a single-head layout and is not yet wired to those heads; the
    unresolved spots are marked with NOTE(review) below.
    """

    def __init__(self, config):
        super().__init__(config)

        self.num_labels_1 = config.num_labels_1
        # Was `confgi.num_labels_2`, which raised NameError on construction.
        self.num_labels_2 = config.num_labels_2
        self.convnext = ConvNextModel(config)

        # Classifier heads
        # NOTE(review): both guards test `config.num_labels`, not the per-head
        # counts - confirm whether `num_labels_1` / `num_labels_2` were meant.
        self.classifier1 = (
            nn.Linear(config.hidden_sizes[-1], config.num_labels_1) if config.num_labels > 0 else nn.Identity()
        )
        self.classifier2 = (
            nn.Linear(config.hidden_sizes[-1], config.num_labels_2) if config.num_labels > 0 else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    def forward(self, pixel_values=None, labels=None, output_hidden_states=None, return_dict=None):
        """Run the backbone, classify the pooled output and optionally compute a loss.

        Returns a tuple when ``return_dict`` is falsy, otherwise a
        ``ConvNextClassifierOutput``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.convnext(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)

        pooled_output = outputs.pooler_output if return_dict else outputs[1]

        # NOTE(review): `self.classifier` is never created (only `classifier1` and
        # `classifier2` exist), and `self.num_labels` used below is not assigned in
        # `__init__`. How the two heads' logits and losses should be combined is a
        # design decision that still has to be made.
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                # Infer the problem type from the label count and label dtype.
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            # The loss classes are taken from `nn`; the bare names were not imported.
            if self.config.problem_type == "regression":
                loss_fct = nn.MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        # NOTE(review): `ConvNextClassifierOutput` is not imported in the visible cells.
        return ConvNextClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
        )

## Model management