In [36]:
import numpy as np
import pandas as pd
import torch
import wandb
import torchvision.transforms as T
import matplotlib.pyplot as plt
from wdd.data_handling.process_data import threshold_data
from huggingface_hub import hf_hub_download
from wdd.model.cnn_spp import CNN_SPP_Net
from wdd.data_handling.torch_dataset import WaferDataset

In [52]:
cnn_channels=(1,3,3)
spp_output_sizes=[(1,1),(3,3),(5,5)]
linear_dims=(20,9)
net=CNN_SPP_Net(cnn_channels,spp_output_sizes,linear_dims)

In [53]:
def init_weights(m):
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.xavier_uniform(m.weight)
        m.bias.data.fill_(0.10)

In [54]:
net.apply(init_weights)

  torch.nn.init.xavier_uniform(m.weight)


CNN_SPP_Net(
  (cnn_layers): Sequential(
    (conv2d0): Conv2d(1, 2, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (bnorm2d0): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (cnn-relu0): ReLU()
    (maxpool2d0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv2d1): Conv2d(2, 2, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (bnorm2d1): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (cnn-relu1): ReLU()
    (maxpool2d1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (linear_layers): Sequential(
    (linear0): Linear(in_features=68, out_features=9, bias=True)
  )
)

In [55]:
training_set=WaferDataset('/mnt/c/Users/lslat/Data/wafer_defect_detection_project/train.pkl')

In [56]:
validation_set=WaferDataset('/mnt/c/Users/lslat/Data/wafer_defect_detection_project/valid.pkl')

In [57]:

training_loader = torch.utils.data.DataLoader(training_set, batch_size=1, shuffle=True, num_workers=2)

In [58]:
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=1, shuffle=False, num_workers=2)

In [59]:
from torch.optim import Adam
 
# Define the loss function with Classification Cross-Entropy loss and an optimizer with Adam optimizer
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = Adam(net.parameters(), lr=0.0001, weight_decay=0.0001)

In [60]:
def train_one_epoch(epoch_index,batch_size):
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data
        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = net(inputs)

        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()
        

        # Adjust learning weights
        if i%batch_size==0:
            optimizer.step()
            optimizer.zero_grad()
            
        # Gather data and report
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(training_loader) + i + 1
            running_loss = 0.
    return last_loss

In [61]:
# Initializing in a separate cell so we can easily add more epochs to the same run
epoch_number = 0
batch_size = 100
EPOCHS = 5

best_vloss = 1_000_000.

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    net.train(True)
    avg_loss = train_one_epoch(epoch_number,batch_size)

    # We don't need gradients on to do reporting
    net.train(False)

    running_vloss = 0.0
    for i, vdata in enumerate(validation_loader):
        vinputs, vlabels = vdata
        voutputs = net(vinputs)
        vloss = loss_fn(voutputs, vlabels)
        running_vloss += vloss

    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))
    epoch_number += 1

EPOCH 1:
  batch 1000 loss: 8.46663906556368
  batch 2000 loss: 8.057754670619964
  batch 3000 loss: 7.923253403544426
  batch 4000 loss: 7.5807909253537655
  batch 5000 loss: 7.253394916474819
  batch 6000 loss: 6.901358069390058
  batch 7000 loss: 6.916176226019859
  batch 8000 loss: 6.602479857683182


KeyboardInterrupt: 