In [1]:
import numpy as np
import pandas as pd
import torch
from torchvision import transforms
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from torchvision.models import resnet50, ResNet50_Weights

### Data exploration

In [2]:
pickle_path = "/datasets/pick-it/dataset_3d_vnir_xrt.pkl"
data = pd.read_pickle(pickle_path)
labels = data["class"].unique().tolist()


In [3]:
class PickItDataset(Dataset):    
    def __init__(self, pickle_path):
        data = pd.read_pickle(pickle_path)
        cube = data["data_cube"].apply(lambda x: x.astype(np.float32))
        mask = data["masks"]
        self.data = cube
        self.masks = mask
        self.labels = data["class"].apply(lambda x: data["class"].unique().tolist().index(x))
        self.transform = transforms.Compose([Rescale((256,256)),Normalize()])

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            index = index.tolist()
        data = torch.from_numpy(self.data[idx])
        mask = torch.tensor(self.masks[idx])
        label = self.labels[idx]
        sample = {"data": data,"mask":mask, "label": label}
        if self.transform:
            sample = self.transform(sample)
        return sample


### Definition of the transforms

In [4]:
class Rescale(object):
    def __init__(self, scale):
        self.scale = scale

    def __call__(self, sample):
        data, mask ,label = sample["data"], sample["mask"],sample["label"]
        data = transforms.Resize(self.scale)(data)
        mask = transforms.Resize(self.scale)(mask.unsqueeze(0)).squeeze(0)
        return {"data": data,"mask":mask ,"label": label}

class Normalize(object):
    def __call__(self, sample):
        data, mask, label = sample["data"], sample["mask"], sample["label"]
        mean = torch.mean(data)
        std = torch.std(data)
        data = transforms.Normalize(mean,std)(data)
        return {"data": data,"mask":mask ,"label": label}



### ResNet50

In [5]:
class our_ResNet50(nn.Module):
    def __init__(self):
        super().__init__()
        self.res_net = resnet50()
        self.res_net.conv1 = nn.Conv2d(11, 64, kernel_size=7, stride=2, padding=3,bias=False)
        self.res_net.fc = nn.Linear(in_features = 2048 , out_features = 8,bias=True)
    def forward(self, x):
        x = self.res_net.forward(x)
        return x

In [6]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
net = our_ResNet50()
net.load_state_dict(torch.load("weigths/RES50_AUG_CE_INV.pt",map_location=torch.device(device)))
net.to(device)
print("Network loaded")

Network loaded


In [9]:
dataset = PickItDataset(pickle_path)

In [10]:
def predict(data, i):
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    data = dataset[i]
    images, labels = data["data"].to(device), data["label"]
    images = images.unsqueeze(0)
    output = net(images).to(device)
    output = int(output.argmax())
    return output, labels

In [13]:
pred, truth = predict(dataset, 3)
print("Prediction:",labels[pred], ",Truth:",labels[truth])

Prediction: aluminium ,Truth: zinc
