In [1]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import models
from torchvision import datasets 
from torch.utils.data import DataLoader
from torchvision.transforms import transforms

In [2]:
trans1=transforms.Compose([transforms.Resize((227,227)),transforms.ToTensor()])
trans2=transforms.Compose([transforms.Resize((224,224)),transforms.ToTensor()])
dataset1=datasets.ImageFolder("F:\BreastCancer_DataSet_Val",transform=trans1)
dataset2=datasets.ImageFolder("F:\BreastCancer_DataSet_Val",transform=trans2)
data1=DataLoader(dataset1,batch_size=512,shuffle=True) #AlexNet accept 227x227 images
data2=DataLoader(dataset2,batch_size=128,shuffle=True) #VGG-16, GoogLenet accept 224x224 images

In [3]:
alexnet=models.alexnet(pretrained=True)
for param in alexnet.parameters():
    param.requires_grad=False     #to stop training of the first few layers
alexnet.classifier=nn.Sequential(nn.Linear(9216,2),nn.Sigmoid()) #only trainable layer

vgg16=models.vgg16(pretrained=True)
for param in vgg16.parameters():
    param.requires_grad=False
vgg16.classifier=nn.Sequential(nn.Linear(25088,2),nn.Sigmoid())

googlenet=models.googlenet(pretrained=True)
for param in googlenet.parameters():
    param.requires_grad=False
googlenet.fc=nn.Sequential(nn.Linear(1024,2),nn.Sigmoid())

In [4]:
epoch=1
criterion=nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(alexnet.parameters(),lr=.001)

def accuracy(data,model):
    correct=0
    total=0
    model.eval()
    with torch.no_grad():
        for images,labels in tqdm(data):
            preds=model(images)
            values,index=preds.max(1)
            correct+=(index==labels).sum()
            total+=preds.size(0)
            acc=correct/total
        print('accuracy',acc.item()*100,'%')

def train_network(data,model):
    for e in range(epoch):
        for images,labels in tqdm(data):
            preds=model(images)
            loss=criterion(preds,labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    accuracy(data,model)

In [5]:
train_network(data1,alexnet)

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))


accuracy 87.63672113418579 %


In [6]:
train_network(data2,vgg16)

HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


accuracy 49.94140565395355 %


In [7]:
train_network(data2,googlenet)

HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


accuracy 63.749998807907104 %
