In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchmetrics
import torchvision.transforms as T
import wandb
from huggingface_hub import hf_hub_download
from sklearn.model_selection import train_test_split
from torch.optim import Adam
from torch.utils.data import WeightedRandomSampler

from wdd.data_handling.augment_data import wafer_train_transforms
from wdd.data_handling.process_data import threshold_data
from wdd.data_handling.pull_data import get_processed_data
from wdd.data_handling.torch_dataset import WaferDataset
from wdd.model.cnn_spp import CNN_SPP_Net,cnn_spp_hypDict
from wdd.model.model_training import train_model


In [None]:
# Hyper-parameters handed to CNN_SPP_Net as a single dict.
cnn_channels = (1, 3)
spp_output_sizes = [(1, 1), (3, 3)]
# NOTE(review): this is a bare int, not a 1-tuple -- `(9)` without a trailing
# comma evaluates to 9. If CNN_SPP_Net expects a sequence of layer widths this
# should be `(9,)` -- TODO confirm against the model's constructor.
linear_dims = 9
model_parameters = dict(
    cnn_channels=cnn_channels,
    spp_output_sizes=spp_output_sizes,
    linear_output_sizes=linear_dims,
)

# Build the network and run its own weight initialisation routine.
net = CNN_SPP_Net(model_parameters)
net.init_weights()

In [None]:
# Load the two frames returned by get_processed_data (unpacked as train / test).
train_df, test_df = get_processed_data()

# Hold out 20% of the training rows as a validation set; the fixed
# random_state makes the split repeatable across runs.
# NOTE(review): the split is not stratified -- confirm that is acceptable
# given the per-class weighting applied further down.
train_df, valid_df = train_test_split(train_df, random_state=42, test_size=0.2)

In [None]:
# Wrap the frames as datasets. Only the training set receives a transform;
# the validation set is built without one.
# NOTE(review): the meaning of the 0.0 argument is defined inside
# wafer_train_transforms -- confirm what it controls.
training_set = WaferDataset(
    train_df,
    transform=wafer_train_transforms(0.0),
)
valid_set = WaferDataset(valid_df)

In [None]:
# Test dataset, built without a transform.
test_set = WaferDataset(test_df)

In [None]:
# Fraction of training samples that fall in each of the 9 classes
# (despite the name, these are class frequencies; the inverse is taken below).
class_weights = torch.Tensor(
    [training_set.y.count(label) for label in range(9)]
) * torch.Tensor([1 / training_set.len])
assert np.isclose(class_weights.sum(), 1), 'class_weights must sum to be one'

# Each sample is weighted by the inverse frequency of its own class, so that
# rare classes are drawn as often as common ones.
sample_weights = torch.Tensor(
    [1 / class_weights[label] for label in training_set.y]
)

# Draw as many indices per epoch as there are training samples
# (WeightedRandomSampler samples with replacement by default).
sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
)

In [None]:
# Fraction of validation samples that fall in each of the 9 classes.
valid_class_weights = torch.Tensor(
    [valid_set.y.count(label) for label in range(9)]
) * torch.Tensor([1 / valid_set.len])
assert np.isclose(valid_class_weights.sum(), 1), 'valid_class_weights must sum to be one'

In [None]:
# One sample per batch; ordering is delegated to the weighted sampler, so
# `shuffle` must not be passed here (DataLoader rejects sampler + shuffle).
training_loader = torch.utils.data.DataLoader(
    training_set,
    sampler=sampler,
    batch_size=1,
    num_workers=0,
)

In [None]:
# Validation loader: one sample per batch, iterated in dataset order.
# Shuffling brings no benefit for evaluation and only makes the iteration
# order differ from run to run, so it is disabled.
valid_loader = torch.utils.data.DataLoader(
    valid_set,
    batch_size=1,
    shuffle=False,
    num_workers=0,
)

In [None]:
# Test loader: one sample per batch, iterated in dataset order.
# Keeping shuffle off makes the order of the predictions reproducible and
# lets them be lined up with the rows of the test set afterwards.
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=1,
    shuffle=False,
    num_workers=0,
)

In [None]:
# Loss functions and optimizer (Adam is imported in the notebook's import cell).
#
# The training loss is plain cross-entropy -- presumably because the weighted
# sampler already balances the classes seen during training (confirm).
train_loss_fn = torch.nn.CrossEntropyLoss()

# The validation loss weights each class by the inverse of its frequency in
# the validation set, since the validation loader is not rebalanced.
valid_loss_fn = torch.nn.CrossEntropyLoss(weight=valid_class_weights.reciprocal())

# Adam over all network parameters, with L2 weight decay.
optimizer = Adam(net.parameters(), lr=0.001, weight_decay=0.0001)

In [None]:
# Train for a single epoch (train_model is imported in the notebook's import
# cell). The training and validation passes each get their own loader and
# loss function.
train_model(
    net,
    training_loader,
    valid_loader,
    train_loss_fn,
    valid_loss_fn,
    optimizer,
    epochs=1,
)

In [None]:
# Run the trained network over the test loader; predict returns three values,
# unpacked here as true labels, predicted labels and predicted probabilities.
(y_trues, y_preds, y_pred_probs) = net.predict(test_loader)

In [None]:
# Evaluation metrics over the 9 classes. For each of accuracy and F1:
#   'micro' pools every sample, 'macro' averages the per-class scores,
#   'none' returns one score per class.
# NOTE(review): torchmetrics >= 0.11 requires a `task=` argument on these
# constructors -- confirm the pinned torchmetrics version.
acc = torchmetrics.Accuracy(num_classes=9, average='micro')
bacc = torchmetrics.Accuracy(num_classes=9, average='macro')
by_class_acc = torchmetrics.Accuracy(num_classes=9, average='none')

f1 = torchmetrics.F1Score(num_classes=9, average='micro')
bf1 = torchmetrics.F1Score(num_classes=9, average='macro')
by_class_f1 = torchmetrics.F1Score(num_classes=9, average='none')


In [None]:
# Overall (micro) accuracy on the test set.
# torchmetrics metrics are called as metric(preds, target).
acc(y_preds, y_trues)

In [None]:
# Per-class accuracy on the test set.
# torchmetrics metrics are called as metric(preds, target). The order matters
# here: per-class accuracy is computed relative to the target labels, so
# passing the arguments the other way round reports per-class precision
# instead of the fraction of each true class that was predicted correctly.
by_class_acc(y_preds, y_trues)

In [None]:
# Micro-averaged F1 on the test set.
# torchmetrics metrics are called as metric(preds, target).
f1(y_preds, y_trues)

In [None]:
# Macro-averaged (class-balanced) F1 on the test set.
# torchmetrics metrics are called as metric(preds, target).
bf1(y_preds, y_trues)

In [None]:
# Per-class F1 on the test set.
# torchmetrics metrics are called as metric(preds, target).
by_class_f1(y_preds, y_trues)

In [None]:
# Scratch: 9 times descending powers of two -> (144, 72, 36, 18, 9).
# NOTE(review): presumably candidate linear-layer widths ending at the
# 9 classes -- confirm, or remove this cell if it is leftover exploration.
tuple(9 * 2 ** exponent for exponent in range(4, -1, -1))

In [None]:
# Scratch arithmetic: evaluates to 1.
# NOTE(review): looks like the "same" padding for a kernel of size 3 --
# confirm, or remove this cell if it is leftover exploration.
(3 - 1) // 2

In [None]:
# Scratch: the first three odd square sizes -> [(1, 1), (3, 3), (5, 5)].
# NOTE(review): presumably candidate spp_output_sizes -- confirm, or remove
# this cell if it is leftover exploration.
[(side, side) for side in range(1, 6, 2)]

In [None]:
# Debug inspection: displays the class of the validation loss object.
# NOTE(review): leftover exploration -- consider removing from the final notebook.
type(valid_loss_fn)

In [None]:
# Debug inspection: displays the class of the validation loader object.
# NOTE(review): leftover exploration -- consider removing from the final notebook.
type(valid_loader)