In [2]:
import torch
import pandas as pd
import os

from torchvision.io import read_image
from torch.utils.data import Dataset

# Annotation table shared by the rest of the notebook (split, label lookups).
bean_annotations = pd.read_csv("../data/beans.csv")
bean_annotations.head()
# Map every defect class name to an integer id, numbered in order of first
# appearance in the "defect_class" column.
DEFECT_CLASSES = {
    defect: index
    for index, defect in enumerate(pd.unique(bean_annotations["defect_class"]))
}
DEFECT_CLASSES

{'burnt': 0,
 'normal': 1,
 'frag': 2,
 'under': 3,
 'quaker': 4,
 'insectOrMold': 5}

In [3]:
from torchvision.transforms import v2

def _make_pipeline(*extra_steps):
    """Build the image pipeline shared by the "train" and "test" splits.

    ``extra_steps`` are inserted right after the resize and before the random
    rotation; every other step is identical for both splits.
    """
    return v2.Compose(
        [
            v2.Resize(size=(400, 400)),
            *extra_steps,
            # Random rotation; the area uncovered by the rotation is filled with white.
            v2.RandomRotation(degrees=(20, 340), fill=(255, 255, 255)),
            # float32 (scaled) so the tensors can be used on the Apple silicon GPU.
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )


# Both splits are randomly rotated; only the training split is also randomly
# flipped horizontally.
transforms = {
    "train": _make_pipeline(v2.RandomHorizontalFlip()),
    "test": _make_pipeline(),
}

In [4]:
class RoastDefectsDataset(Dataset):
    """Bean images and their defect labels, described by an annotation CSV.

    The first CSV column is read as the image file name and the
    ``defect_class`` column as the label. An image is loaded from
    ``<root_dir>/<first four "-"-separated parts of the file name>/<file name>``.

    Args:
        csv_file: Path of the annotation CSV.
        root_dir: Directory that contains the per-group image folders.
        transform: Optional callable applied to every loaded image.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        self.bean_annotations = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.bean_annotations)

    def get_labels(self):
        """Return the ``defect_class`` column (one class name per image)."""
        return self.bean_annotations["defect_class"]

    def __getitem__(self, item):
        """Load one sample.

        Args:
            item: Positional row index; a tensor index is converted to a
                plain Python value first.

        Returns:
            Tuple ``(image, class_id, img_name)`` where ``class_id`` is looked
            up in the module-level ``DEFECT_CLASSES`` mapping.
        """
        if torch.is_tensor(item):
            # The tensor method is ``tolist``; the previous ``toList`` spelling
            # raised AttributeError for every tensor index.
            item = item.tolist()

        img_name = self.bean_annotations.iloc[item, 0]
        # Images are grouped in folders named after the first four name parts.
        img_dir = "-".join(img_name.split("-")[0:4])
        img_path = os.path.join(self.root_dir, img_dir, img_name)
        image = read_image(img_path)

        if self.transform:
            image = self.transform(image)

        defect_name = self.bean_annotations["defect_class"].iloc[item]
        annotations = DEFECT_CLASSES[defect_name]
        return image, annotations, img_name

In [5]:
from sklearn.model_selection import train_test_split

# 80/20 split of the annotation rows, stratified on the defect class so both
# parts keep the same class proportions; the fixed seed makes it repeatable.
train, test = train_test_split(
    bean_annotations,
    stratify=bean_annotations["defect_class"],
    random_state=42,
    train_size=0.8,
)
# Per-class sample counts of the training part.
train["defect_class"].value_counts()

defect_class
normal          1048
quaker           782
frag             237
under             83
burnt             40
insectOrMold      38
Name: count, dtype: int64

In [6]:
# Per-class sample counts of the held-out test part of the split.
test["defect_class"].value_counts()

defect_class
normal          263
quaker          196
frag             59
under            21
burnt            10
insectOrMold      9
Name: count, dtype: int64

In [9]:
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchsampler import ImbalancedDatasetSampler

# Both datasets read the same CSV and image root and differ only in their
# transform; the train/test split is applied through the samplers below.
data_train = RoastDefectsDataset(
    csv_file="../data/beans.csv",
    root_dir="../data/processed-whitebg",
    transform=transforms["train"],
)

data_test = RoastDefectsDataset(
    csv_file="../data/beans.csv",
    root_dir="../data/processed-whitebg",
    transform=transforms["test"],
)

# Samplers restrict each loader to the row indices of its split.
# NOTE(review): presumably ImbalancedDatasetSampler rebalances classes using
# `labels`; confirm it pairs `labels` with `indices` correctly when the index
# is in shuffled (non-sorted) order, as it is after train_test_split.
train_sampler = ImbalancedDatasetSampler(
    data_train, labels=train["defect_class"], indices=list(train.index)
)
test_sampler = SubsetRandomSampler(list(test.index))


# Each loader yields its whole split as a single batch. The test loader is
# sized by the test split (it previously reused len(train)).
train_loader = DataLoader(data_train, sampler=train_sampler, batch_size=len(train))
test_loader = DataLoader(data_test, sampler=test_sampler, batch_size=len(test))

dataloaders = {"test": test_loader, "train": train_loader}
dataset_sizes = {"test": len(test), "train": len(train)}

In [55]:
def get_misclassified_results(model, loader=None, device="mps"):
    """Run ``model`` over a loader and list every misclassified sample.

    Args:
        model: Classifier whose output has one score per class along dim 1.
        loader: Iterable of ``(inputs, labels, img_names)`` batches. When
            omitted, the module-level ``test_loader`` is used, resolved at
            call time (so it is whichever loader that name is bound to then).
        device: Device the model and the inputs are moved to.

    Returns:
        A list of ``{"predicted", "actual", "filename"}`` dicts, one per
        sample whose predicted class differs from its label.
    """
    if loader is None:
        loader = test_loader
    # Class id -> class name, in the same first-appearance order that
    # DEFECT_CLASSES is built from; computed once rather than per sample.
    class_names = list(pd.unique(bean_annotations["defect_class"]))

    model.to(device)
    model.eval()
    all_labels = []
    all_preds = []
    all_names = []
    with torch.no_grad():
        for inputs, labels, img_names in loader:
            outputs = model(inputs.to(device))
            _, preds = torch.max(outputs, 1)
            all_labels.extend(labels.tolist())
            all_preds.extend(preds.tolist())
            # Accumulate file names per batch too; using only the last batch's
            # names misaligned them with the accumulated predictions.
            all_names.extend(img_names)
    return [
        {
            "predicted": class_names[pred],
            "actual": class_names[actual],
            "filename": filename,
        }
        for pred, actual, filename in zip(all_preds, all_labels, all_names)
        if pred != actual
    ]

In [56]:
from torchvision import models

# Every network gets a final linear layer with one output per defect class
# before its saved weights are restored.
n_defect_classes = len(DEFECT_CLASSES)

# Two MobileNetV2 instances, each restored from a different checkpoint file.
mobileNet_pt = models.mobilenet_v2()
mobileNet_no_pt = models.mobilenet_v2()
mobileNet_in_ftrs = mobileNet_no_pt.classifier[1].in_features

mobileNet_no_pt.classifier[1] = torch.nn.Linear(mobileNet_in_ftrs, n_defect_classes)
mobileNet_no_pt.load_state_dict(
    torch.load("trained-models/mobileNet-no-pretraining0-5gamma.pt")
)

mobileNet_pt.classifier[1] = torch.nn.Linear(mobileNet_in_ftrs, n_defect_classes)
mobileNet_pt.load_state_dict(
    torch.load("trained-models/mobilenet-pretrained-95-acc-36m16s-40epochs.pt")
)

# ResNet-50: its classification head is the `fc` attribute.
resnet50_pt = models.resnet50()
resnet_in_ftrs = resnet50_pt.fc.in_features
resnet50_pt.fc = torch.nn.Linear(resnet_in_ftrs, n_defect_classes)
resnet50_pt.load_state_dict(torch.load("trained-models/resnet_50_91_acc.pt"))

<All keys matched successfully>

In [57]:
# Misclassified samples for the MobileNet restored from the "pretrained" checkpoint.
mobilenet_pt_results = get_misclassified_results(mobileNet_pt)

In [58]:
# Misclassified samples for the MobileNet restored from the "no-pretraining" checkpoint.
mobilenet_no_pt_results = get_misclassified_results(mobileNet_no_pt)

In [59]:
# Misclassified samples for the ResNet-50 model.
resnet50_pt_results = get_misclassified_results(resnet50_pt)

In [60]:
import json

# Misclassification lists keyed by a human-readable model name.
results_to_dump = {}
results_to_dump["MobileNet (pretrained)"] = mobilenet_pt_results
results_to_dump["MobileNet (no pre training)"] = mobilenet_no_pt_results
results_to_dump["Resnet 50"] = resnet50_pt_results

# Persist as JSON so later cells can reload and extend the results.
with open("missclassified_results.json", "w") as f:
    f.write(json.dumps(results_to_dump))

# with open('my_dict.json') as f:
#     my_dict = json.load(f)

## Repeat the misclassification analysis above for the KNN-2 classifier from the paper

In [7]:
# Pipeline for the KNN features: resize and convert to scaled float32 only.
# No flip or rotation is applied; grayscale and normalization stay disabled.
_knn_steps = [
    v2.Resize(size=(400, 400)),
    v2.ToDtype(torch.float32, scale=True),
    # v2.Grayscale(),
    # v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform_knn = v2.Compose(_knn_steps)

In [21]:
# One dataset over all annotated images, using the KNN transform; the samplers
# below restrict each loader to the rows of its split.
dataset = RoastDefectsDataset(
    "../data/beans.csv",
    "../data/processed-whitebg",
    transform=transform_knn,
)

# NOTE: these rebind train_sampler / test_sampler / train_loader / test_loader,
# replacing any loaders previously stored under the same names.
train_sampler = SubsetRandomSampler(indices=list(train.index))
test_sampler = SubsetRandomSampler(indices=list(test.index))


# batch_size equals the split size, so each loader yields one batch with the
# whole split.
train_loader = DataLoader(dataset, batch_size=len(train), sampler=train_sampler)
test_loader = DataLoader(dataset, batch_size=len(test), sampler=test_sampler)

In [22]:
def _collect_batches(loader):
    """Gather every batch of ``loader`` into (images, labels, file names).

    The previous loops overwrote their targets on each iteration, so only the
    last batch survived; here all batches are kept. The batch axis is never
    squeezed, so a one-sample split still has a leading sample dimension.
    """
    images, labels, names = [], [], []
    for data, label, batch_names in loader:
        images.append(data)
        labels.append(label)
        names.extend(batch_names)
    if len(images) == 1:
        # Common case here (one batch per split): avoid copying a large tensor.
        return images[0], labels[0], names
    return torch.cat(images), torch.cat(labels), names


X_train, y_train, _train_filenames = _collect_batches(train_loader)
# `filenames` stays aligned with X_test / y_test for the misclassification report.
X_test, y_test, filenames = _collect_batches(test_loader)

In [23]:
from skimage.exposure import histogram
from sklearn.neighbors import KNeighborsClassifier

def _channel_histogram(img):
    """Return the flattened 256-bin per-channel histogram of one image tensor.

    Axis 0 of ``img`` is treated as the channel axis.
    """
    counts, _bin_centers = histogram(img.numpy(), nbins=256, channel_axis=0)
    return counts.ravel()


# Feature vectors for the KNN: one flattened histogram per image.
X_train_hists = [_channel_histogram(img) for img in X_train]
X_test_hists = [_channel_histogram(img) for img in X_test]

knn = KNeighborsClassifier(metric="canberra", n_neighbors=6)

In [24]:
# Fit on the training histograms, then classify the test histograms
# (fit returns the estimator itself, so the calls can be chained).
preds = knn.fit(X_train_hists, y_train).predict(X_test_hists)

In [25]:
# Class id -> class name, in the same first-appearance order DEFECT_CLASSES is
# built from; computed once instead of once per test sample.
class_names = list(pd.unique(bean_annotations["defect_class"]))

# One record per test sample whose predicted class differs from its label.
knn_2_missclassifications = [
    {
        "predicted": class_names[pred],
        "actual": class_names[actual],
        "filename": filename,
    }
    # .tolist() turns the numpy predictions and the tensor labels into plain
    # Python ints, so the comparison and the lookups use ordinary integers.
    for pred, actual, filename in zip(preds.tolist(), y_test.tolist(), filenames)
    if pred != actual
]

In [30]:
import json

# Reload the previously saved results and add the KNN entry alongside them.
with open("missclassified_results.json") as f:
    all_results = json.loads(f.read())
all_results.update({"KNN-2": knn_2_missclassifications})

{'MobileNet (pretrained)': [{'predicted': 'normal',
   'actual': 'frag',
   'filename': 'brazil-wushwush-nat-frag-1-16.png'},
  {'predicted': 'normal',
   'actual': 'quaker',
   'filename': 'kenya-sl28-washed-quaker-6-22.png'},
  {'predicted': 'quaker',
   'actual': 'frag',
   'filename': 'ethiopia-yirga-CM-frag-0-12.png'},
  {'predicted': 'quaker',
   'actual': 'under',
   'filename': 'brazil-catuai-nat-under-2-1.png'},
  {'predicted': 'quaker',
   'actual': 'frag',
   'filename': 'brazil-wushwush-nat-frag-0-24.png'},
  {'predicted': 'quaker',
   'actual': 'frag',
   'filename': 'columbia-caturra-washed-frag-1-1.png'},
  {'predicted': 'frag',
   'actual': 'quaker',
   'filename': 'brazil-catuai-nat-quaker-3-2.png'},
  {'predicted': 'quaker',
   'actual': 'burnt',
   'filename': 'ethiopia-yirga-CM-burnt-1-6.png'},
  {'predicted': 'normal',
   'actual': 'quaker',
   'filename': 'kenya-sl28-washed-quaker-2-0.png'},
  {'predicted': 'normal',
   'actual': 'quaker',
   'filename': 'kenya-sl

In [32]:
# Overwrite the results file with the combined (CNN + KNN) results.
with open("missclassified_results.json", "w") as f:
    f.write(json.dumps(all_results))