# Setup and imports

In [1]:
!pip install git+https://github.com/octheo/futur.git

Collecting git+https://github.com/octheo/futur.git
  Cloning https://github.com/octheo/futur.git to /tmp/pip-req-build-gfwos_qe
  Running command git clone --filter=blob:none --quiet https://github.com/octheo/futur.git /tmp/pip-req-build-gfwos_qe
  Resolved https://github.com/octheo/futur.git to commit 31a61f3a9affe9169226ff3a548239fc1b67521d
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting torcheval (from thor==0.1)
  Downloading torcheval-0.0.7-py3-none-any.whl.metadata (8.6 kB)
Downloading torcheval-0.0.7-py3-none-any.whl (179 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m179.2/179.2 kB[0m [31m12.1 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: thor
  Building wheel for thor (setup.py) ... [?25l[?25hdone
  Created wheel for thor: filename=thor-0.1-py3-none-any.whl size=7580 sha256=646bca7d7375d7211400d96d7f57312355afe8f50854d0c374e453b8b6d777f7
  Stored in directory: /tmp/pip-ephem-wheel-cache-zd7nnu59/w

In [2]:
import os
import re
import glob
import math

import wandb
from torchvision import transforms, models
from torch.utils.data import DataLoader
import torch
import torch.nn.init as init
import torch.optim as optim

from thor.splitters import supervised_AD as SAD
from thor.trainers import trainers, metrics
from thor.trainers import loss
from thor.datasets import mvtech

# WandB token config

In [3]:
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
wandb_key = user_secrets.get_secret("WB-SupervisedImageClassification")

# Run config

In [4]:
base_directory = "/kaggle/input/mvtech-anomaly-detection/mvtec_anomaly_detection"

In [5]:
categories = {
    'bottle': True,
    'cable': True,
    'capsule': True,
    'carpet': True,
    'grid': True,
    'hazelnut': True,
    'leather': True,
    'metal_nut': True,
    'pill': True,
    'screw': True,
    'tile': True,
    'toothbrush': True,
    'transistor': True,
    'wood': True,
    'zipper': True
}

selected_classes = [key for key, value in categories.items() if value]
print(selected_classes)

['bottle', 'cable', 'capsule', 'carpet', 'grid', 'hazelnut', 'leather', 'metal_nut', 'pill', 'screw', 'tile', 'toothbrush', 'transistor', 'wood', 'zipper']


In [6]:
conf = {
    "WandB-activated": True,
    "train_classes": selected_classes,
    "num_classes": len(selected_classes),
    "train_task": 'Image classification',
    "fine-tuned backbone": False,
    "cls_head": "linear",

    "batch_size" : 32,
    "epochs" : 2,
    "lr" : 1e-4,
    "l2_decay": 0.,
    'metrics': ["f1", "precision", "recall"],
    
    'img_size': (224,224),
    'loss': 'CrossEntropy',
}

# Model

In [7]:
from transformers import ViTModel, ViTFeatureExtractor, ViTImageProcessor, ViTForImageClassification

model_name = "google/vit-base-patch16-224"
conf["model"] = model_name

processor = ViTImageProcessor.from_pretrained(model_name)
vit_model = ViTModel.from_pretrained(model_name)

if not conf["fine-tuned backbone"]:
    for param in vit_model.parameters():
        param.requires_grad = False

preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/69.7k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
class LinearClassifier(torch.nn.Module):
    def __init__(self, vit_model, num_classes):
        super(LinearClassifier, self).__init__()
        self.vit = vit_model
        self.classifier = torch.nn.Linear(vit_model.config.hidden_size, num_classes)

    def forward(self, pixel_values):
        outputs = self.vit(pixel_values=pixel_values, output_hidden_states=True)
        cls_token = outputs.hidden_states[-1][:,0,:]
        logits = self.classifier(cls_token)
        return logits

In [9]:
base_model = LinearClassifier(vit_model, conf["num_classes"])

# Pipeline

In [10]:
splits = []
for c in selected_classes:
    splits.append(SAD.MVTech_AD_supervised_cls_split(base_directory, c, train_split=1, val_split=0, dist_adjust=False, multiclass=False))

In [11]:
train = []
val = []
test = []
for i, split in enumerate(splits):
    for j, sample in enumerate(split.no_defect_samples):
        split.no_defect_samples[j] = (sample, i)
    train += split.no_defect_samples[:math.ceil(0.8*split.nb_no_defect_samples)]
    val += split.no_defect_samples[math.ceil(0.8*split.nb_no_defect_samples):math.ceil(0.9*split.nb_no_defect_samples)]
    test += split.no_defect_samples[math.ceil(0.9*split.nb_no_defect_samples):]

In [12]:
train_transform = transforms.Compose([transforms.Resize(size=conf['img_size']),
                                     transforms.ToTensor()])

In [13]:
datasets = [mvtech.MVTechDataset_cls(split, transform=train_transform) for split in (train, val, test)]
dataloaders = [DataLoader(dataset, batch_size=conf["batch_size"], shuffle=True) for dataset in datasets]

In [14]:
if conf["WandB-activated"]:
    wandb.login(key=wandb_key)
    
    run = wandb.init(
        project="Classification",
        config=conf
    )

    model_artifact = wandb.Artifact("ViT-B", 
                                    type="model",
                                    description="Base training",
                                    metadata=conf
                                   )

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mtheomoreau-thor[0m ([33mtheomoreau-thor-octo-technology[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Tracking run with wandb version 0.19.1
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/working/wandb/run-20250319_155351-vjtl0yxa[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mmajor-monkey-10[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/theomoreau-thor-octo-technology/Classification[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/theomoreau-thor-octo-technology/Classification/runs/vjtl0yxa[0m


# Model training

In [15]:
optimizer = optim.Adam(base_model.classifier.parameters(), lr=conf['lr'], weight_decay=conf["l2_decay"])
m = metrics.ClassificationMetrics(conf["num_classes"], conf['metrics'])
l = loss.ClassificationLoss("CE")

trainer = trainers.ClassificationTrainer(optimizer, l, m)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for param in base_model.classifier.parameters():
    param.requires_grad = True

save_path = "/kaggle/working/model.pth"

trainer.train(
    base_model, 
    conf["epochs"], 
    dataloaders[0], 
    dataloaders[1], 
    conf["num_classes"], 
    device, 
    save_path=save_path, 
    wandb_run=run, 
    model_artifact=model_artifact
)

Train epoch 0: 100%|██████████| 103/103 [03:44<00:00,  2.18s/Batch]
Metric computation: 100%|██████████| 3/3 [07:08<00:00, 142.90s/Metric]
Val epoch 0: 100%|██████████| 13/13 [00:25<00:00,  1.98s/Batch]
Metric computation: 100%|██████████| 3/3 [00:53<00:00, 17.77s/Metric]


train loss: 0.00010508100739588925, val loss: 0.000983919482678175


Train epoch 1: 100%|██████████| 103/103 [02:24<00:00,  1.40s/Batch]
Metric computation: 100%|██████████| 3/3 [07:01<00:00, 140.48s/Metric]
Val epoch 1: 100%|██████████| 13/13 [00:17<00:00,  1.37s/Batch]
Metric computation: 100%|██████████| 3/3 [00:51<00:00, 17.07s/Metric]


train loss: 3.0163769135434252e-05, val loss: 0.0005945794594784578




# Model evaluation

In [16]:
trainer.validate_model(base_model, dataloaders[2], conf["num_classes"], device, run, log_images=True)

Metric computation: 100%|██████████| 3/3 [00:51<00:00, 17.27s/Metric]


In [17]:
wandb.finish()

[34m[1mwandb[0m: uploading output.log; uploading working/model.pth; uploading config.yaml
[34m[1mwandb[0m: uploading working/model.pth
[34m[1mwandb[0m: uploading working/model.pth; uploading history steps 238-238, summary, console lines 11-11
[34m[1mwandb[0m: uploading working/model.pth
[34m[1mwandb[0m:                                                                                
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run history:
[34m[1mwandb[0m:  avg_train_loss █▁
[34m[1mwandb[0m:    avg_val_loss █▁
[34m[1mwandb[0m:        train_f1 ▁█
[34m[1mwandb[0m:      train_loss █▆▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
[34m[1mwandb[0m: train_precision ▁█
[34m[1mwandb[0m:    train_recall ▁█
[34m[1mwandb[0m:          val_f1 ▁█
[34m[1mwandb[0m:        val_loss ▂▄▄▂▃▄▂▂▂▃█▂▂▂▁▁▁▁▁▂▃▁▁▁▁▁
[34m[1mwandb[0m:   val_precision ▁█
[34m[1mwandb[0m:      val_recall ▁█
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run summary:
[34m[1mwandb[0m:      avg_train_loss 