### Import Libraries

In [1]:
import os
import glob

import cv2
import numpy as np

import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader, random_split
# from torchvision.models import resnet50, vit_b_16
# from torchvision import transforms
# from torchmetrics import Accuracy
from torchinfo import summary

from utilities import AITEX
from model_architectures import VAE, MiniUNet


### Get Data

In [2]:
defect_codes = {
    0: "Normal",
    2:	"Broken end",
    6:	"Broken yarn",
    10:	"Broken pick",
    16:	"Weft curling",
    19:	"Fuzzyball",
    22:	"Cut selvage",
    23:	"Crease",
    25:	"Warp ball",
    27:	"Knots",
    29:	"Contamination",
    30: "Nep",
    36:	"Weft crack",
}

class AITEXPatched(AITEX):
    def __init__(self, *args, normal_only=False, defect_only=False, **kwargs,):
        super(AITEXPatched, self).__init__(*args, **kwargs)

        self.patched_images = []
        self.patched_masks = []
        self.has_defect = []
        for index, img in enumerate(self.images):
            img_new = cv2.resize(img, (4096, 256)) / 255. 
            self.patched_images.extend([img_new[:,i:i+256] for i in range(0, 4096, 256)])

            mask_new = cv2.resize(self.masks[index], (4096, 256))
            mask_patches = [mask_new[:,i:i+256] for i in range(0, 4096, 256)]
            self.patched_masks.extend(mask_patches)

            self.has_defect.extend([1 if np.sum(x) > 0 else 0 for x in mask_patches])
        
        if normal_only:
            indices = [x for x, y in enumerate(self.has_defect) if y==0]
            self.patched_images = [self.patched_images[x] for x in indices]
            self.patched_masks = [self.patched_masks[x] for x in indices]
            self.has_defect = [self.has_defect[x] for x in indices]
        if defect_only:
            indices = [x for x, y in enumerate(self.has_defect) if y!=0]
            self.patched_images = [self.patched_images[x] for x in indices]
            self.patched_masks = [self.patched_masks[x] for x in indices]
            self.has_defect = [self.has_defect[x] for x in indices]
    
    def __len__(self):
        """Get length of full dataset."""
        return len(self.patched_images)    
    
    def __getitem__(self, idx):
        """Return specific index of dataset."""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        img = self.patched_images[idx]
        
        return np.stack([img, img, img], axis=0), self.has_defect[idx]

In [8]:
root = os.path.abspath(os.path.join(os.getcwd(), ".."))
data_dir = os.path.join(root, "data")
aitex_dir = os.path.join(data_dir, "aitex")

data = AITEXPatched(aitex_dir, greyscale=True)#, normal_only=True)
num_samples = len(data)
train_samples = int(num_samples * 0.9)
val_samples = num_samples - train_samples
train, val = random_split(data, [train_samples, val_samples])

# bs = 32
# train_loader = DataLoader(train, batch_size=bs, shuffle=True)
# val_loader = DataLoader(val, batch_size=bs, shuffle=True)

### Experiment with ViT

In [4]:
from transformers import ViTFeatureExtractor, ViTForImageClassification
from datasets import load_metric

metric = load_metric("accuracy")
labels = {0: "normal", 1: "defect"}

feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224', 
    num_labels=2,
    id2label=labels, 
    label2id={y: x for x, y in labels.items()},
    ignore_mismatched_sizes=True
)

  from .autonotebook import tqdm as notebook_tqdm
  metric = load_metric("accuracy")
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([2]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [32]:
def process_example(example):
    inputs = feature_extractor(example[0], return_tensors="pt")
    inputs['pixel_values'] = inputs['pixel_values'].reshape((3, 224, 224))
    inputs["label"] = example[1]
    return inputs

def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['label'] for x in batch])
    }

def compute_metrics(p):
    return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)

train_ds = [process_example(x) for x in train]
val_ds = [process_example(x) for x in val]

In [40]:
from transformers import TrainingArguments

training_args = TrainingArguments(
  output_dir="./vit-base-beans",
  per_device_train_batch_size=16,
  evaluation_strategy="steps",
  num_train_epochs=4,
  fp16=True,
  save_steps=100,
  eval_steps=100,
  logging_steps=10,
  learning_rate=2e-4,
  save_total_limit=2,
  remove_unused_columns=False,
  push_to_hub=False,
  report_to='tensorboard',
  load_best_model_at_end=True,
)

In [41]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=feature_extractor,
)

In [42]:
train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)

  0%|          | 0/444 [07:56<?, ?it/s]
                                                 
 23%|██▎       | 100/444 [03:17<01:40,  3.41it/s]

{'loss': 0.2558, 'learning_rate': 0.0001954954954954955, 'epoch': 0.09}


                                                 
 23%|██▎       | 100/444 [03:20<01:40,  3.41it/s]

{'loss': 0.2379, 'learning_rate': 0.000190990990990991, 'epoch': 0.18}


                                                 
 23%|██▎       | 100/444 [03:23<01:40,  3.41it/s]

{'loss': 0.1623, 'learning_rate': 0.0001864864864864865, 'epoch': 0.27}


                                                 
 23%|██▎       | 100/444 [03:26<01:40,  3.41it/s]

{'loss': 0.2618, 'learning_rate': 0.000181981981981982, 'epoch': 0.36}


                                                 
 23%|██▎       | 100/444 [03:29<01:40,  3.41it/s]

{'loss': 0.1765, 'learning_rate': 0.0001774774774774775, 'epoch': 0.45}


                                                 
 23%|██▎       | 100/444 [03:32<01:40,  3.41it/s]

{'loss': 0.232, 'learning_rate': 0.000172972972972973, 'epoch': 0.54}


                                                 
 23%|██▎       | 100/444 [03:35<01:40,  3.41it/s]

{'loss': 0.2027, 'learning_rate': 0.00016846846846846846, 'epoch': 0.63}


                                                 
 23%|██▎       | 100/444 [03:38<01:40,  3.41it/s]

{'loss': 0.182, 'learning_rate': 0.00016396396396396395, 'epoch': 0.72}


                                                 
 23%|██▎       | 100/444 [03:41<01:40,  3.41it/s]

{'loss': 0.1743, 'learning_rate': 0.00015945945945945947, 'epoch': 0.81}


                                                 
 23%|██▎       | 100/444 [03:44<01:40,  3.41it/s]

{'loss': 0.2311, 'learning_rate': 0.00015495495495495496, 'epoch': 0.9}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                                                 

[A[A                                         
 23%|██▎       | 100/444 [03:46<01:40,  3.41it/s]
[A
[A

{'eval_loss': 0.1840788871049881, 'eval_accuracy': 0.9695431472081218, 'eval_runtime': 1.911, 'eval_samples_per_second': 206.174, 'eval_steps_per_second': 13.082, 'epoch': 0.9}


                                                 
 23%|██▎       | 100/444 [03:51<01:40,  3.41it/s]

{'loss': 0.16, 'learning_rate': 0.00015045045045045046, 'epoch': 0.99}


                                                 
 23%|██▎       | 100/444 [03:53<01:40,  3.41it/s]

{'loss': 0.2694, 'learning_rate': 0.00014594594594594595, 'epoch': 1.08}


                                                 
 23%|██▎       | 100/444 [03:56<01:40,  3.41it/s]

{'loss': 0.2053, 'learning_rate': 0.00014144144144144144, 'epoch': 1.17}


                                                 
 23%|██▎       | 100/444 [04:00<01:40,  3.41it/s]

{'loss': 0.1654, 'learning_rate': 0.00013693693693693693, 'epoch': 1.26}


                                                 
 23%|██▎       | 100/444 [04:03<01:40,  3.41it/s]

{'loss': 0.1847, 'learning_rate': 0.00013243243243243243, 'epoch': 1.35}


                                                 
 23%|██▎       | 100/444 [04:06<01:40,  3.41it/s]

{'loss': 0.2323, 'learning_rate': 0.00012792792792792795, 'epoch': 1.44}


                                                 
 23%|██▎       | 100/444 [04:09<01:40,  3.41it/s]

{'loss': 0.1755, 'learning_rate': 0.00012342342342342344, 'epoch': 1.53}


                                                 
 23%|██▎       | 100/444 [04:12<01:40,  3.41it/s]

{'loss': 0.1418, 'learning_rate': 0.00011891891891891893, 'epoch': 1.62}


                                                 
 23%|██▎       | 100/444 [04:15<01:40,  3.41it/s]

{'loss': 0.2526, 'learning_rate': 0.00011441441441441443, 'epoch': 1.71}


                                                 
 23%|██▎       | 100/444 [04:18<01:40,  3.41it/s]

{'loss': 0.169, 'learning_rate': 0.00010990990990990993, 'epoch': 1.8}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                                                 

[A[A                                         
 23%|██▎       | 100/444 [04:20<01:40,  3.41it/s]
[A
[A

{'eval_loss': 0.1403149962425232, 'eval_accuracy': 0.9695431472081218, 'eval_runtime': 2.122, 'eval_samples_per_second': 185.674, 'eval_steps_per_second': 11.781, 'epoch': 1.8}


                                                 
 23%|██▎       | 100/444 [04:24<01:40,  3.41it/s]

{'loss': 0.2644, 'learning_rate': 0.0001054054054054054, 'epoch': 1.89}


                                                 
 23%|██▎       | 100/444 [04:27<01:40,  3.41it/s]

{'loss': 0.1742, 'learning_rate': 0.00010090090090090089, 'epoch': 1.98}


                                                 
 23%|██▎       | 100/444 [04:30<01:40,  3.41it/s]

{'loss': 0.2546, 'learning_rate': 9.639639639639641e-05, 'epoch': 2.07}


                                                 
 23%|██▎       | 100/444 [04:33<01:40,  3.41it/s]

{'loss': 0.1735, 'learning_rate': 9.18918918918919e-05, 'epoch': 2.16}


                                                 
 23%|██▎       | 100/444 [04:36<01:40,  3.41it/s]

{'loss': 0.2293, 'learning_rate': 8.738738738738738e-05, 'epoch': 2.25}


                                                 
 23%|██▎       | 100/444 [04:40<01:40,  3.41it/s]

{'loss': 0.2346, 'learning_rate': 8.288288288288289e-05, 'epoch': 2.34}


                                                 
 23%|██▎       | 100/444 [04:43<01:40,  3.41it/s]

{'loss': 0.1699, 'learning_rate': 7.837837837837838e-05, 'epoch': 2.43}


                                                 
 23%|██▎       | 100/444 [04:46<01:40,  3.41it/s]

{'loss': 0.2275, 'learning_rate': 7.387387387387387e-05, 'epoch': 2.52}


                                                 
 23%|██▎       | 100/444 [04:49<01:40,  3.41it/s]

{'loss': 0.1723, 'learning_rate': 6.936936936936938e-05, 'epoch': 2.61}


                                                 
 23%|██▎       | 100/444 [04:52<01:40,  3.41it/s]

{'loss': 0.1507, 'learning_rate': 6.486486486486487e-05, 'epoch': 2.7}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                                                 
[A                                              

 23%|██▎       | 100/444 [04:54<01:40,  3.41it/s]
[A
[A

{'eval_loss': 0.13638249039649963, 'eval_accuracy': 0.9695431472081218, 'eval_runtime': 1.801, 'eval_samples_per_second': 218.767, 'eval_steps_per_second': 13.881, 'epoch': 2.7}


                                                 
 23%|██▎       | 100/444 [04:58<01:40,  3.41it/s]

{'loss': 0.2365, 'learning_rate': 6.0360360360360365e-05, 'epoch': 2.79}


                                                 
 23%|██▎       | 100/444 [05:01<01:40,  3.41it/s]

{'loss': 0.1906, 'learning_rate': 5.585585585585585e-05, 'epoch': 2.88}


                                                 
 23%|██▎       | 100/444 [05:04<01:40,  3.41it/s]

{'loss': 0.1906, 'learning_rate': 5.135135135135135e-05, 'epoch': 2.97}


                                                 
 23%|██▎       | 100/444 [05:07<01:40,  3.41it/s]

{'loss': 0.1605, 'learning_rate': 4.684684684684685e-05, 'epoch': 3.06}


                                                 
 23%|██▎       | 100/444 [05:11<01:40,  3.41it/s]

{'loss': 0.2042, 'learning_rate': 4.234234234234234e-05, 'epoch': 3.15}


                                                 
 23%|██▎       | 100/444 [05:14<01:40,  3.41it/s]

{'loss': 0.1806, 'learning_rate': 3.783783783783784e-05, 'epoch': 3.24}


                                                 
 23%|██▎       | 100/444 [05:17<01:40,  3.41it/s]

{'loss': 0.1809, 'learning_rate': 3.3333333333333335e-05, 'epoch': 3.33}


                                                 
 23%|██▎       | 100/444 [05:20<01:40,  3.41it/s]

{'loss': 0.1593, 'learning_rate': 2.882882882882883e-05, 'epoch': 3.42}


                                                 
 23%|██▎       | 100/444 [05:23<01:40,  3.41it/s]

{'loss': 0.2037, 'learning_rate': 2.4324324324324327e-05, 'epoch': 3.51}


                                                 
 23%|██▎       | 100/444 [05:26<01:40,  3.41it/s]

{'loss': 0.2108, 'learning_rate': 1.981981981981982e-05, 'epoch': 3.6}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                                                 
[A                                              

 23%|██▎       | 100/444 [05:28<01:40,  3.41it/s]
[A
[A

{'eval_loss': 0.13753864169120789, 'eval_accuracy': 0.9695431472081218, 'eval_runtime': 2.067, 'eval_samples_per_second': 190.614, 'eval_steps_per_second': 12.095, 'epoch': 3.6}


                                                 
 23%|██▎       | 100/444 [05:33<01:40,  3.41it/s]

{'loss': 0.2383, 'learning_rate': 1.5315315315315316e-05, 'epoch': 3.69}


                                                 
 23%|██▎       | 100/444 [05:36<01:40,  3.41it/s]

{'loss': 0.1711, 'learning_rate': 1.0810810810810812e-05, 'epoch': 3.78}


                                                 
 23%|██▎       | 100/444 [05:39<01:40,  3.41it/s]

{'loss': 0.2704, 'learning_rate': 6.306306306306306e-06, 'epoch': 3.87}


                                                 
 23%|██▎       | 100/444 [05:42<01:40,  3.41it/s]

{'loss': 0.119, 'learning_rate': 1.801801801801802e-06, 'epoch': 3.96}


                                                 
100%|██████████| 444/444 [02:29<00:00,  2.98it/s]


{'train_runtime': 149.195, 'train_samples_per_second': 94.963, 'train_steps_per_second': 2.976, 'train_loss': 0.2015069006262599, 'epoch': 4.0}
***** train metrics *****
  epoch                    =        4.0
  train_loss               =     0.2015
  train_runtime            = 0:02:29.19
  train_samples_per_second =     94.963
  train_steps_per_second   =      2.976


In [44]:
train_results.metrics

{'train_runtime': 149.195,
 'train_samples_per_second': 94.963,
 'train_steps_per_second': 2.976,
 'train_loss': 0.2015069006262599,
 'epoch': 4.0}