### Import Libraries

- note: a no-intelligence model will perform at 0.5732 accuracy

In [89]:
import os
import glob

import cv2
import numpy as np

import matplotlib.pyplot as plt

import torch
from torch.optim import Adam
from torch.utils.data import DataLoader, random_split
from torchvision.models import resnet50, vit_b_16
from torchvision import transforms
from torchmetrics import Accuracy

from utilities import AITEX


### Get Data

In [74]:
defect_codes = {
    0: "Normal",
    2:	"Broken end",
    6:	"Broken yarn",
    10:	"Broken pick",
    16:	"Weft curling",
    19:	"Fuzzyball",
    22:	"Cut selvage",
    23:	"Crease",
    25:	"Warp ball",
    27:	"Knots",
    29:	"Contamination",
    30: "Nep",
    36:	"Weft crack",
}

class AITEXClassification(AITEX):
    def __init__(self, *args, **kwargs):
        super(AITEXClassification, self).__init__(*args, **kwargs)

        self.image_tensors = []
        resize = transforms.Resize((256, 4096))
        means = []
        stdevs = []
        for img in self.images:
        #     # img_fft = np.log(abs(np.fft.fftshift(np.fft.fft2(img))))
            img_fft = np.log(abs(np.fft.fftshift(np.fft.fft2(img))))
            img_3ch = np.repeat(img_fft.reshape((1,) + img.shape), 3, 0)
            img_tensor = resize(torch.Tensor(img_3ch)) #.permute(2, 0, 1))
            means.append(img_tensor.mean(dim=(1,2)))
            stdevs.append(img_tensor.std(dim=(1,2)))
            self.image_tensors.append(img_tensor)
        self.mean = torch.stack(means, dim=0).mean(dim=0)
        self.stdev = torch.stack(stdevs, dim=0).mean(dim=0)
        self.normalizer = transforms.Normalize(self.mean, self.stdev)

        self.defect_to_one_hot = {y: x for x,y in enumerate(defect_codes)}
        self.one_hot_classes = [self.defect_to_one_hot[x] for x in self.classes]
        self.one_hot_classes = torch.nn.functional.one_hot(torch.Tensor(self.one_hot_classes).type(torch.int64))
        
    def __getitem__(self, idx):
        """Return specific index of dataset."""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        return self.normalizer(self.image_tensors[idx]), self.one_hot_classes[idx]

In [75]:
root = os.path.abspath(os.path.join(os.getcwd(), ".."))
data_dir = os.path.join(root, "data")
aitex_dir = os.path.join(data_dir, "aitex")

data = AITEXClassification(aitex_dir, greyscale=True)
num_samples = len(data)
train_samples = int(num_samples * 0.9)
val_samples = num_samples - train_samples
train, val = random_split(data, [train_samples, val_samples])

bs = 4
train_loader = DataLoader(train, batch_size=bs, shuffle=True)
val_loader = DataLoader(val, batch_size=bs, shuffle=True)

### Train

In [76]:
model = resnet50(pretrained=True)
model.fc = torch.nn.Linear(model.fc.in_features, len(defect_codes))
torch.nn.init.xavier_uniform_(model.fc.weight)
# model.conv1.in_channels = 1
# model.conv1.weight.data.normal_(0, 0.001)
model.cuda()



ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [78]:
num_epochs = 10
device="cuda"
loss_fn = torch.nn.CrossEntropyLoss()
# accuracy_fn = Accuracy(task="multiclass")
optimizer = Adam(model.parameters(), lr=1e-2, weight_decay=1e-07)

for epoch in range(0, num_epochs):
    model.train()
    losses = torch.tensor(0.).cuda()
    accuracy = torch.tensor(0.).cuda()
    for x_in, y_in in train_loader:
        x = x_in.to(device)
        y = y_in.to(device)

        y_pred = model(x)
        loss = loss_fn(y_pred, y.type(torch.float32))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses += loss
        accuracy += (torch.argmax(y_pred, axis=1) == torch.argmax(y, axis=1)).sum()

    print(
        f"Epoch: {epoch}, Loss: {torch.round(losses / len(train), decimals=3)}, Accuracy: {torch.round(accuracy / len(train), decimals=3)}"
    )

Epoch: 0, Loss: 0.5070000290870667, Accuracy: 0.597000002861023
Epoch: 1, Loss: 0.38999998569488525, Accuracy: 0.6110000014305115
Epoch: 2, Loss: 0.38499999046325684, Accuracy: 0.6110000014305115
Epoch: 3, Loss: 0.3959999978542328, Accuracy: 0.6060000061988831
Epoch: 4, Loss: 0.39500001072883606, Accuracy: 0.6110000014305115
Epoch: 5, Loss: 0.3790000081062317, Accuracy: 0.6110000014305115
Epoch: 6, Loss: 0.3919999897480011, Accuracy: 0.597000002861023
Epoch: 7, Loss: 0.3919999897480011, Accuracy: 0.6110000014305115
Epoch: 8, Loss: 0.367000013589859, Accuracy: 0.6150000095367432
Epoch: 9, Loss: 0.3499999940395355, Accuracy: 0.6060000061988831
