In [None]:
# torch_snippets is star-imported first so the explicit imports below take
# precedence over any same-named exports it provides.
from torch_snippets import *

import os
from glob import glob

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from PIL import Image
from torch import optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
# Run on the GPU whenever torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

class brain_GAN(Dataset):
    """Dataset of real images for the GAN.

    Every file in ``folder`` matching ``pattern`` (by default only
    ``kane.jpg``) is loaded as RGB, resized to 64x64 and normalized to
    [-1, 1], the same range as the generator's Tanh output.

    Args:
        folder: directory to search for images.
        pattern: glob pattern, relative to ``folder``, selecting the files.
    """

    def __init__(self, folder, pattern='kane.jpg'):
        # os.path.join instead of a hard-coded '\\' so the path also works off Windows;
        # sorted() so sample order does not depend on the filesystem.
        self.img = sorted(glob(os.path.join(folder, pattern)))
        self.transform = transforms.Compose([
            transforms.Resize((64, 64)),
            transforms.ToTensor(),  # float tensor in [0, 1], shape (C, H, W)
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # -> [-1, 1]
        ])

    def __len__(self):
        return len(self.img)

    def __getitem__(self, idx):
        # Close the file handle; convert() returns an independent, loaded copy.
        with Image.open(self.img[idx]) as raw:
            im = raw.convert('RGB')
        # No clamp(0, 1) here: it zeroed every pixel below mid-grey and kept real
        # images in [0, 1] while generated ones span [-1, 1].
        return self.transform(im).float().to(device)
        
# dataset -- directory on the author's machine; adjust for your environment
DATA_DIR = 'C:\\Users\\debas\\Desktop'
data = brain_GAN(DATA_DIR)
print(len(data))
if len(data) == 0:
    # Without this check the training loop runs zero iterations per epoch
    # and the (untrained) weights are still saved at the end.
    raise FileNotFoundError(f'no training images found in {DATA_DIR}')
# dataloader
dl = DataLoader(data, batch_size=1, shuffle=True)

#define generator

class generator(nn.Module):
    """MLP generator mapping a 100-d latent vector to a flattened 3x64x64 image.

    The final Tanh bounds every output value to [-1, 1]. Layer order is fixed
    so the state_dict keys stay ``model.0``, ``model.2`` and ``model.4``.
    """

    def __init__(self):
        super().__init__()
        hidden_widths = [100, 256, 512]
        stack = []
        for fan_in, fan_out in zip(hidden_widths, hidden_widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2)]
        stack += [nn.Linear(hidden_widths[-1], 3 * 64 * 64), nn.Tanh()]
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Return ``self.model(x)``; assumes ``x`` is (N, 100) latent noise."""
        image_flat = self.model(x)
        return image_flat


#plt.imshow(data.__getitem__(1).permute(1,2,0))

#discriminator
class discriminator(nn.Module):
    """MLP discriminator scoring a flattened 3x64x64 image.

    The final Sigmoid turns the score into a probability in [0, 1]. Layer order
    is fixed so the state_dict keys stay ``model.0``, ``model.3`` and ``model.6``.
    """

    def __init__(self):
        super().__init__()

        def hidden(fan_in, fan_out, p_drop):
            # Linear -> LeakyReLU(0.2) -> Dropout, the repeated hidden unit.
            return [nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2), nn.Dropout(p_drop)]

        self.model = nn.Sequential(
            *hidden(3 * 64 * 64, 512, 0.3),
            *hidden(512, 256, 0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``self.model(x)``; assumes ``x`` is (N, 3*64*64)."""
        prob_real = self.model(x)
        return prob_real


def noise(batch, latent_dim=100, device=None):
    """Sample a batch of standard-normal latent vectors for the generator.

    Args:
        batch: number of vectors to draw.
        latent_dim: length of each vector; the default 100 matches the
            generator's first ``nn.Linear(100, 256)`` layer.
        device: where to allocate the tensor. ``None`` keeps torch's default
            device (normally the CPU), as before; pass the models' device to
            avoid a separate host-to-device copy.

    Returns:
        Tensor of shape (batch, latent_dim) drawn from N(0, 1).
    """
    return torch.randn(batch, latent_dim, device=device)

# Networks are built discriminator-first, so parameter initialisation
# draws from the RNG in the same order as before.
dis = discriminator().to(device)
gen = generator().to(device)

# The discriminator ends in a Sigmoid and outputs probabilities, hence plain BCE.
loss = nn.BCELoss()

# Same learning rate for both players; `optim` is torch.optim.
optim_d = optim.Adam(dis.parameters(), lr=2e-4)
optim_g = optim.Adam(gen.parameters(), lr=2e-4)

# epochs
NUM_EPOCHS = 2000

for epoch in range(NUM_EPOCHS):
    d_loss_sum, g_loss_sum, n_batches = 0.0, 0.0, 0
    for batch in dl:
        # The MLP discriminator takes flat (N, 3*64*64) vectors.
        real_data = batch.to(device).view(batch.size(0), -1)
        n = real_data.size(0)
        # Label tensors must live on the same device as the discriminator output,
        # otherwise BCELoss fails when training on CUDA.
        real_labels = torch.ones(n, 1, device=device)
        fake_labels = torch.zeros(n, 1, device=device)

        # noise() allocates on torch's default device; move it next to the models.
        # One generator pass feeds both updates (the generator has no stochastic layers).
        z = noise(n).to(device)
        fake_data = gen(z)

        # train discriminator: fakes -> 0, reals -> 1.
        # detach() keeps this step from backpropagating into the generator.
        optim_d.zero_grad()
        loss_df = loss(dis(fake_data.detach()), fake_labels)
        loss_dr = loss(dis(real_data), real_labels)
        total_loss_d = loss_df + loss_dr
        total_loss_d.backward()
        optim_d.step()

        # train generator: push the just-updated discriminator to label fakes as real.
        # This backward also writes dis's .grad, which optim_d.zero_grad() clears next batch.
        optim_g.zero_grad()
        loss_gf = loss(dis(fake_data), real_labels)
        loss_gf.backward()
        optim_g.step()

        d_loss_sum += total_loss_d.item()
        g_loss_sum += loss_gf.item()
        n_batches += 1

    if n_batches:
        print(f'epoch={epoch + 1} total_loss={d_loss_sum / n_batches:.4f} '
              f'loss_gf={g_loss_sum / n_batches:.4f}')


# Save CPU copies of the weights so the checkpoint loads without a GPU,
# without moving the live models off `device` (module.to('cpu') is in-place).
checkpoint = {
    'dis': {k: v.detach().cpu() for k, v in dis.state_dict().items()},
    'gen': {k: v.detach().cpu() for k, v in gen.state_dict().items()},
    'optim_d': optim_d.state_dict(),
    'optim_g': optim_g.state_dict(),
}
torch.save(checkpoint, 'GAN_brain_tumor.pth')



1
epoch=1
torch.Size([1, 3, 64, 64])
torch.Size([1, 12288])
df=tensor([[0.5023]], grad_fn=<SigmoidBackward0>)
dr=tensor([[0.5154]], grad_fn=<SigmoidBackward0>)
total_loss=1.3606363534927368
loss_gf=0.8299264311790466
epoch=2
torch.Size([1, 3, 64, 64])
torch.Size([1, 12288])
df=tensor([[0.4756]], grad_fn=<SigmoidBackward0>)
dr=tensor([[0.6548]], grad_fn=<SigmoidBackward0>)
total_loss=1.0689915418624878
loss_gf=0.8492233753204346
epoch=3
torch.Size([1, 3, 64, 64])
torch.Size([1, 12288])
df=tensor([[0.4678]], grad_fn=<SigmoidBackward0>)
dr=tensor([[0.8396]], grad_fn=<SigmoidBackward0>)
total_loss=0.8055168390274048
loss_gf=0.8593610525131226
epoch=4
torch.Size([1, 3, 64, 64])
torch.Size([1, 12288])
df=tensor([[0.4262]], grad_fn=<SigmoidBackward0>)
dr=tensor([[0.9045]], grad_fn=<SigmoidBackward0>)
total_loss=0.6559545993804932
loss_gf=0.949846625328064
epoch=5
torch.Size([1, 3, 64, 64])
torch.Size([1, 12288])
df=tensor([[0.3894]], grad_fn=<SigmoidBackward0>)
dr=tensor([[0.9812]], grad_fn=<