In [1]:
import os
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import models
from torchvision import datasets 
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
import matplotlib.pyplot as plt
os.environ['KMP_DUPLICATE_LIB_OK']="True"

In [2]:
class CNNModel(nn.Module):
    def __init__(self,in_ch,num_cl,hidden_ch=64):
        super(CNNModel,self).__init__()
        self.Model=nn.Sequential(
        self.block(in_ch,hidden_ch,final=False),
        self.block(hidden_ch,hidden_ch,final=False),
        nn.MaxPool2d(2,2),
        self.block(hidden_ch,hidden_ch,final=False),
        nn.MaxPool2d(2,2),
        self.block(hidden_ch,num_cl,final=True)
        )
        
    def block(self,in_ch,out_ch,final):           
        if not final:
            return nn.Sequential(
            nn.Conv2d(in_ch,out_ch,kernel_size=5,stride=1,padding=2),
            nn.BatchNorm2d(out_ch),
            nn.ReLU()
            )
        else:
            return nn.Sequential(
            nn.Conv2d(in_ch,out_ch,kernel_size=12,stride=1,padding=0),
            nn.Sigmoid()
            )
    def forward(self,x):
        x=self.Model(x)
        return x.reshape(x.shape[0],-1)

In [3]:
x=torch.randn(10,3,50,50)
model=CNNModel(3,2)
print(model(x).shape)

torch.Size([10, 2])


In [4]:
custom_transforms=transforms.Compose([              #data augmentation 
    transforms.RandomRotation(degrees=45),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor()])

dataset=datasets.ImageFolder("F:\BreastCancer_DataSet",transform=custom_transforms) 
data=DataLoader(dataset,batch_size=64,shuffle=True)                #training set
print(dataset)
dataset1=datasets.ImageFolder("F:\BreastCancer_DataSet_test",transform=transforms.ToTensor())
data1=DataLoader(dataset1,batch_size=64,shuffle=True)              #test_set

Dataset ImageFolder
    Number of datapoints: 6350
    Root location: F:\BreastCancer_DataSet
    StandardTransform
Transform: Compose(
               RandomRotation(degrees=[-45.0, 45.0], resample=False, expand=False)
               RandomHorizontalFlip(p=0.5)
               RandomVerticalFlip(p=0.5)
               ToTensor()
           )


In [5]:
epoch=1
criterion=nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(model.parameters(),lr=.001)
model.load_state_dict(torch.load(r"F:\weights_94.txt")) #for loading already trained model weights

def train_network(data,model):
    losses=[]
    a=[]
    for e in range(epoch):
        for images,labels in tqdm(data):
            preds=model(images)
            loss=criterion(preds,labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    torch.save(model.state_dict(), r"F:\weights.txt") #for saving state dictionary

def accuracy(data,model):
    correct=0
    total=0
    accuracy=[]
    model.eval()
    with torch.no_grad():
        for images,labels in tqdm(data):
            preds=model(images)
            values,index=preds.max(1)
            correct+=(index==labels).sum()
            total+=preds.size(0)
            acc=correct/total
    return acc.item()

def confusion_matrix(data,model):
    model.eval()
    indexes=torch.tensor([])
    labels=torch.tensor([])
    with torch.no_grad():
        for image,label in data:
            preds=model(image)
            values,index=preds.max(1)
            indexes=torch.cat((indexes,index),dim=0).int()
            labels=torch.cat((labels,label),dim=0).int()
            
        stacked=torch.stack((labels,indexes),dim=1)
        confusion=torch.zeros(2,2,dtype=torch.int32)
        
        for i in stacked:
            true,pred=i.tolist()
            confusion[true,pred]+=1
        print(confusion)

In [6]:
#train_network(data,model) 

In [7]:
accuracy(data1,model) #data1 for measuring test set accuracy

HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))




0.9023323655128479

In [8]:
confusion_matrix(data,model)

tensor([[3074,  210],
        [ 176, 2890]], dtype=torch.int32)
