In [26]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms

In [27]:
# Colab only: uncomment the next two lines to mount Google Drive.
#from google.colab import drive
#drive.mount('/content/drive')
# Display whether a CUDA GPU is visible to this kernel; run() selects "cuda:0"
# when it is and falls back to the CPU otherwise.
torch.cuda.is_available()

False

In [28]:
def get_chest_images(file_path="/storage/pipemon/6.867-xray-project/data/",
                     batch_size=64, image_size=(256, 256), num_workers=2):
    """Build the train and validation DataLoaders for the chest X-ray images.

    Images are resized to ``image_size`` and converted to tensors with
    ``transforms.ToTensor()``; no normalization is applied.

    Args:
        file_path: Dataset root containing ``train/`` and ``test/``
            sub-directories in ``torchvision.datasets.ImageFolder`` layout
            (one sub-directory per class). Defaults to the original location.
            A trailing slash is optional.
        batch_size: Samples per batch for both loaders.
        image_size: ``(height, width)`` passed to ``transforms.Resize``.
            ``Chest_Disease_Net``'s fully connected head (12544 inputs) is
            sized for 256 x 256 images, so keep the default when training it.
        num_workers: Worker processes used by each DataLoader.

    Returns:
        ``(train_loader, val_loader)``. The train loader shuffles every epoch;
        the validation loader is built from ``test/`` and keeps a fixed order.
    """
    import os

    transform_img = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])

    # Training data: shuffled so each epoch sees a different batch order.
    trainset = datasets.ImageFolder(root=os.path.join(file_path, "train"), transform=transform_img)
    train_loader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    # Validation data comes from the "test" split; order is kept stable.
    valset = datasets.ImageFolder(root=os.path.join(file_path, "test"), transform=transform_img)
    val_loader = torch.utils.data.DataLoader(
        valset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, val_loader
    

In [29]:
class Chest_Disease_Net(nn.Module):
    """Convolutional classifier for chest X-ray images.

    Five convolutional stages (ReLU after each conv, max-pooling after stages
    1, 2 and 5) feed three fully connected layers that output raw logits
    (no softmax; pair with ``nn.CrossEntropyLoss``).

    The fully connected head is sized for 3 x 256 x 256 inputs; other
    spatial sizes produce a different flattened length and fail in ``fc1``.

    Args:
        num_classes: Number of output logits. Defaults to 2, the original
            two-class setup.
    """

    # Flattened feature length after layer5 for a 3 x 256 x 256 input:
    #   256 -conv k16 s3-> 81 -pool k5 s2-> 39 -conv k7-> 33 -pool k5 s2-> 15
    #   -conv k3 p1 (x3)-> 15 -pool k3 s2-> 7, with 256 channels: 256 * 7 * 7.
    FLAT_FEATURES = 256 * 7 * 7

    def __init__(self, num_classes=2):
        super().__init__()

        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=16, stride=3, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=5, stride=2)
        )

        self.layer2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=7, stride=1, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=5, stride=2)
        )

        # layers 3-5 use padding=1 with 3x3 kernels, preserving spatial size.
        self.layer3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU()
        )

        self.layer4 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU()
        )

        self.layer5 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )

        # Attribute names (layer1..layer5, fc1..fc3) are kept stable so that
        # previously saved state_dict checkpoints still load.
        self.fc1 = nn.Sequential(
            nn.Linear(self.FLAT_FEATURES, 1024),
            nn.ReLU(),
        )
        self.fc2 = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU()
        )
        self.fc3 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)`` for a batch ``x``.

        ``x`` is expected to be ``(batch, 3, 256, 256)``.
        """
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)
        # Collapse (C, H, W) into one feature vector per sample.
        out = torch.flatten(out, 1)
        out = self.fc1(out)
        out = self.fc2(out)
        out = self.fc3(out)
        return out

In [None]:
def _evaluate_error(model, loader, device, desc=None):
    """Return the top-1 error, in percent, of ``model`` over every batch of ``loader``.

    Switches ``model`` to eval mode and disables autograd; the caller must
    switch back to train mode before further training. Returns NaN when the
    loader yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(loader, desc=desc):
            images = images.to(device)
            labels = labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        return float("nan")
    return 100.0 * (1 - correct / total)


def run(num_epochs=10, lr=0.01, momentum=0.9):
    """Train ``Chest_Disease_Net`` with SGD and report per-epoch metrics.

    Each epoch: trains over the full training loader, prints the mean batch
    loss, saves the weights to ``model.<epoch>`` in the working directory,
    then prints top-1 error on the full training and validation sets.

    Args:
        num_epochs: Number of passes over the training data.
        lr: SGD learning rate.
        momentum: SGD momentum.
    """
    # gc was used at the end of every epoch but never imported (NameError).
    import gc

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = Chest_Disease_Net().to(device)

    # Build the loaders once; the dataset scan is not free.
    trainloader, valloader = get_chest_images()

    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

    for epoch in range(num_epochs):
        # Evaluation below leaves the model in eval mode; restore train mode
        # every epoch so layers like dropout/batch-norm behave correctly.
        model.train()
        running_loss = 0.0
        for inputs, labels in tqdm(trainloader, desc="epoch %d train" % epoch):
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        # max(..., 1) guards against an empty training loader.
        print("Loss: " + str(running_loss / max(len(trainloader), 1)))
        # save after every epoch
        torch.save(model.state_dict(), "model.%d" % epoch)

        train_error = _evaluate_error(model, trainloader, device, desc="epoch %d eval train" % epoch)
        print('Top One Error of the network on train images: %.2f %%' % train_error)

        val_error = _evaluate_error(model, valloader, device, desc="epoch %d eval val" % epoch)
        print('Top One Error of the network on validation images: %.2f %%' % val_error)

        gc.collect()
run()


  0%|          | 0/1352 [00:00<?, ?it/s][A
  0%|          | 1/1352 [00:03<1:07:51,  3.01s/it][A
  0%|          | 2/1352 [00:04<56:12,  2.50s/it]  [A
  0%|          | 3/1352 [00:05<47:42,  2.12s/it][A
  0%|          | 4/1352 [00:06<42:00,  1.87s/it][A
  0%|          | 5/1352 [00:08<37:41,  1.68s/it][A
  0%|          | 6/1352 [00:09<34:37,  1.54s/it][A
  1%|          | 7/1352 [00:10<32:42,  1.46s/it][A
  1%|          | 8/1352 [00:11<31:19,  1.40s/it][A
  1%|          | 9/1352 [00:13<30:11,  1.35s/it][A
  1%|          | 10/1352 [00:14<29:37,  1.32s/it][A
  1%|          | 11/1352 [00:15<28:51,  1.29s/it][A
  1%|          | 12/1352 [00:16<28:36,  1.28s/it][A
  1%|          | 13/1352 [00:18<28:11,  1.26s/it][A
  1%|          | 14/1352 [00:19<27:57,  1.25s/it][A
  1%|          | 15/1352 [00:20<27:41,  1.24s/it][A
  1%|          | 16/1352 [00:21<27:27,  1.23s/it][A
  1%|▏         | 17/1352 [00:22<27:42,  1.25s/it][A
  1%|▏         | 18/1352 [00:24<27:20,  1.23s/it][A
  1%|▏

 22%|██▏       | 304/1352 [06:18<21:42,  1.24s/it][A
 23%|██▎       | 305/1352 [06:20<21:30,  1.23s/it][A
 23%|██▎       | 306/1352 [06:21<21:27,  1.23s/it][A
 23%|██▎       | 307/1352 [06:22<21:25,  1.23s/it][A
 23%|██▎       | 308/1352 [06:23<21:28,  1.23s/it][A
 23%|██▎       | 309/1352 [06:24<21:33,  1.24s/it][A
 23%|██▎       | 310/1352 [06:26<21:30,  1.24s/it][A
 23%|██▎       | 311/1352 [06:27<21:20,  1.23s/it][A
 23%|██▎       | 312/1352 [06:28<21:21,  1.23s/it][A
 23%|██▎       | 313/1352 [06:29<21:27,  1.24s/it][A
 23%|██▎       | 314/1352 [06:31<21:23,  1.24s/it][A
 23%|██▎       | 315/1352 [06:32<21:12,  1.23s/it][A
 23%|██▎       | 316/1352 [06:33<21:11,  1.23s/it][A
 23%|██▎       | 317/1352 [06:34<21:02,  1.22s/it][A
 24%|██▎       | 318/1352 [06:36<21:01,  1.22s/it][A
 24%|██▎       | 319/1352 [06:37<21:06,  1.23s/it][A
 24%|██▎       | 320/1352 [06:38<21:16,  1.24s/it][A
 24%|██▎       | 321/1352 [06:39<21:12,  1.23s/it][A
 24%|██▍       | 322/1352 [0

 45%|████▍     | 606/1352 [12:31<15:06,  1.21s/it][A
 45%|████▍     | 607/1352 [12:32<15:00,  1.21s/it][A
 45%|████▍     | 608/1352 [12:33<14:52,  1.20s/it][A
 45%|████▌     | 609/1352 [12:34<14:38,  1.18s/it][A
 45%|████▌     | 610/1352 [12:36<14:53,  1.20s/it][A
 45%|████▌     | 611/1352 [12:37<14:56,  1.21s/it][A
 45%|████▌     | 612/1352 [12:38<14:57,  1.21s/it][A
 45%|████▌     | 613/1352 [12:39<15:02,  1.22s/it][A
 45%|████▌     | 614/1352 [12:41<15:03,  1.22s/it][A
 45%|████▌     | 615/1352 [12:42<15:04,  1.23s/it][A
 46%|████▌     | 616/1352 [12:43<15:04,  1.23s/it][A
 46%|████▌     | 617/1352 [12:44<14:51,  1.21s/it][A
 46%|████▌     | 618/1352 [12:45<14:50,  1.21s/it][A
 46%|████▌     | 619/1352 [12:47<14:57,  1.22s/it][A
 46%|████▌     | 620/1352 [12:48<14:57,  1.23s/it][A
 46%|████▌     | 621/1352 [12:49<15:03,  1.24s/it][A
 46%|████▌     | 622/1352 [12:50<15:16,  1.26s/it][A
 46%|████▌     | 623/1352 [12:52<15:13,  1.25s/it][A
 46%|████▌     | 624/1352 [1

 67%|██████▋   | 908/1352 [18:45<09:11,  1.24s/it][A
 67%|██████▋   | 909/1352 [18:46<09:11,  1.24s/it][A
 67%|██████▋   | 910/1352 [18:47<09:09,  1.24s/it][A
 67%|██████▋   | 911/1352 [18:48<09:08,  1.24s/it][A
 67%|██████▋   | 912/1352 [18:49<09:06,  1.24s/it][A
 68%|██████▊   | 913/1352 [18:51<08:59,  1.23s/it][A
 68%|██████▊   | 914/1352 [18:52<08:57,  1.23s/it][A
 68%|██████▊   | 915/1352 [18:53<09:00,  1.24s/it][A
 68%|██████▊   | 916/1352 [18:54<09:02,  1.24s/it][A
 68%|██████▊   | 917/1352 [18:56<09:01,  1.25s/it][A
 68%|██████▊   | 918/1352 [18:57<09:00,  1.25s/it][A
 68%|██████▊   | 919/1352 [18:58<09:02,  1.25s/it][A
 68%|██████▊   | 920/1352 [18:59<09:04,  1.26s/it][A
 68%|██████▊   | 921/1352 [19:01<08:59,  1.25s/it][A
 68%|██████▊   | 922/1352 [19:02<09:02,  1.26s/it][A
 68%|██████▊   | 923/1352 [19:03<08:55,  1.25s/it][A
 68%|██████▊   | 924/1352 [19:04<08:52,  1.24s/it][A
 68%|██████▊   | 925/1352 [19:06<08:48,  1.24s/it][A
 68%|██████▊   | 926/1352 [1

 89%|████████▉ | 1206/1352 [24:54<03:00,  1.23s/it][A
 89%|████████▉ | 1207/1352 [24:55<02:58,  1.23s/it][A
 89%|████████▉ | 1208/1352 [24:56<02:58,  1.24s/it][A
 89%|████████▉ | 1209/1352 [24:58<02:55,  1.23s/it][A
 89%|████████▉ | 1210/1352 [24:59<02:54,  1.23s/it][A
 90%|████████▉ | 1211/1352 [25:00<02:54,  1.24s/it][A
 90%|████████▉ | 1212/1352 [25:01<02:51,  1.22s/it][A
 90%|████████▉ | 1213/1352 [25:02<02:48,  1.22s/it][A
 90%|████████▉ | 1214/1352 [25:04<02:47,  1.21s/it][A
 90%|████████▉ | 1215/1352 [25:05<02:46,  1.22s/it][A
 90%|████████▉ | 1216/1352 [25:06<02:46,  1.22s/it][A
 90%|█████████ | 1217/1352 [25:07<02:44,  1.22s/it][A
 90%|█████████ | 1218/1352 [25:09<02:44,  1.23s/it][A
 90%|█████████ | 1219/1352 [25:10<02:44,  1.23s/it][A
 90%|█████████ | 1220/1352 [25:11<02:43,  1.24s/it][A
 90%|█████████ | 1221/1352 [25:12<02:40,  1.23s/it][A
 90%|█████████ | 1222/1352 [25:14<02:40,  1.23s/it][A
 90%|█████████ | 1223/1352 [25:15<02:40,  1.24s/it][A
 91%|█████

Loss: 0.6696971963936761



  0%|          | 1/1352 [00:02<50:24,  2.24s/it][A
  0%|          | 2/1352 [00:02<38:13,  1.70s/it][A
  0%|          | 3/1352 [00:03<33:17,  1.48s/it][A
  0%|          | 4/1352 [00:04<26:22,  1.17s/it][A
  0%|          | 5/1352 [00:05<25:42,  1.15s/it][A
  0%|          | 6/1352 [00:05<20:57,  1.07it/s][A
  1%|          | 7/1352 [00:06<22:10,  1.01it/s][A
  1%|          | 8/1352 [00:07<18:26,  1.21it/s][A
  1%|          | 9/1352 [00:08<20:44,  1.08it/s][A
  1%|          | 10/1352 [00:08<17:25,  1.28it/s][A
  1%|          | 11/1352 [00:09<19:46,  1.13it/s][A
  1%|          | 12/1352 [00:10<16:39,  1.34it/s][A
  1%|          | 13/1352 [00:11<19:24,  1.15it/s][A
  1%|          | 14/1352 [00:11<16:27,  1.35it/s][A
  1%|          | 15/1352 [00:13<19:04,  1.17it/s][A
  1%|          | 16/1352 [00:13<16:14,  1.37it/s][A
  1%|▏         | 17/1352 [00:14<19:02,  1.17it/s][A
  1%|▏         | 18/1352 [00:15<16:11,  1.37it/s][A
  1%|▏         | 19/1352 [00:16<19:19,  1.15it/s][A
 

 23%|██▎       | 305/1352 [04:11<14:35,  1.20it/s][A
 23%|██▎       | 306/1352 [04:11<13:11,  1.32it/s][A
 23%|██▎       | 307/1352 [04:12<15:30,  1.12it/s][A
 23%|██▎       | 308/1352 [04:13<13:18,  1.31it/s][A
 23%|██▎       | 309/1352 [04:14<16:17,  1.07it/s][A
 23%|██▎       | 310/1352 [04:15<13:58,  1.24it/s][A
 23%|██▎       | 311/1352 [04:16<16:59,  1.02it/s][A
 23%|██▎       | 312/1352 [04:16<14:12,  1.22it/s][A
 23%|██▎       | 313/1352 [04:18<17:00,  1.02it/s][A
 23%|██▎       | 314/1352 [04:18<14:24,  1.20it/s][A
 23%|██▎       | 315/1352 [04:20<16:20,  1.06it/s][A
 23%|██▎       | 316/1352 [04:20<13:42,  1.26it/s][A
 23%|██▎       | 317/1352 [04:21<15:25,  1.12it/s][A
 24%|██▎       | 318/1352 [04:22<13:02,  1.32it/s][A
 24%|██▎       | 319/1352 [04:23<14:50,  1.16it/s][A
 24%|██▎       | 320/1352 [04:23<12:44,  1.35it/s][A
 24%|██▎       | 321/1352 [04:24<14:17,  1.20it/s][A
 24%|██▍       | 322/1352 [04:25<12:08,  1.41it/s][A
 24%|██▍       | 323/1352 [0

 45%|████▍     | 607/1352 [08:19<10:37,  1.17it/s][A
 45%|████▍     | 608/1352 [08:19<09:04,  1.37it/s][A
 45%|████▌     | 609/1352 [08:20<10:39,  1.16it/s][A
 45%|████▌     | 610/1352 [08:21<09:02,  1.37it/s][A
 45%|████▌     | 611/1352 [08:22<10:30,  1.18it/s][A
 45%|████▌     | 612/1352 [08:22<08:54,  1.39it/s][A
 45%|████▌     | 613/1352 [08:24<10:47,  1.14it/s][A
 45%|████▌     | 614/1352 [08:24<09:07,  1.35it/s][A
 45%|████▌     | 615/1352 [08:25<10:53,  1.13it/s][A
 46%|████▌     | 616/1352 [08:26<09:12,  1.33it/s][A
 46%|████▌     | 617/1352 [08:27<10:58,  1.12it/s][A
 46%|████▌     | 618/1352 [08:27<09:15,  1.32it/s][A
 46%|████▌     | 619/1352 [08:29<11:31,  1.06it/s][A
 46%|████▌     | 620/1352 [08:29<09:43,  1.25it/s][A
 46%|████▌     | 621/1352 [08:31<11:30,  1.06it/s][A
 46%|████▌     | 622/1352 [08:31<09:45,  1.25it/s][A
 46%|████▌     | 623/1352 [08:32<11:01,  1.10it/s][A
 46%|████▌     | 624/1352 [08:33<09:21,  1.30it/s][A
 46%|████▌     | 625/1352 [0

 67%|██████▋   | 909/1352 [12:24<06:41,  1.10it/s][A
 67%|██████▋   | 910/1352 [12:25<05:37,  1.31it/s][A
 67%|██████▋   | 911/1352 [12:26<07:01,  1.05it/s][A
 67%|██████▋   | 912/1352 [12:27<05:54,  1.24it/s][A
 68%|██████▊   | 913/1352 [12:28<07:02,  1.04it/s][A
 68%|██████▊   | 914/1352 [12:29<05:52,  1.24it/s][A
 68%|██████▊   | 915/1352 [12:30<07:10,  1.02it/s][A
 68%|██████▊   | 916/1352 [12:30<05:58,  1.22it/s][A
 68%|██████▊   | 917/1352 [12:32<07:12,  1.01it/s][A
 68%|██████▊   | 918/1352 [12:32<05:57,  1.21it/s][A
 68%|██████▊   | 919/1352 [12:33<06:39,  1.08it/s][A
 68%|██████▊   | 920/1352 [12:34<05:39,  1.27it/s][A
 68%|██████▊   | 921/1352 [12:35<06:21,  1.13it/s][A
 68%|██████▊   | 922/1352 [12:35<05:24,  1.32it/s][A
 68%|██████▊   | 923/1352 [12:37<06:11,  1.15it/s][A
 68%|██████▊   | 924/1352 [12:37<05:14,  1.36it/s][A
 68%|██████▊   | 925/1352 [12:38<05:56,  1.20it/s][A
 68%|██████▊   | 926/1352 [12:38<05:02,  1.41it/s][A
 69%|██████▊   | 927/1352 [1

 89%|████████▉ | 1207/1352 [16:29<02:03,  1.17it/s][A
 89%|████████▉ | 1208/1352 [16:30<01:55,  1.24it/s][A
 89%|████████▉ | 1209/1352 [16:31<01:56,  1.23it/s][A
 89%|████████▉ | 1210/1352 [16:32<01:55,  1.23it/s][A
 90%|████████▉ | 1211/1352 [16:32<01:50,  1.27it/s][A
 90%|████████▉ | 1212/1352 [16:33<02:00,  1.16it/s][A
 90%|████████▉ | 1213/1352 [16:34<01:47,  1.29it/s][A
 90%|████████▉ | 1214/1352 [16:35<01:59,  1.15it/s][A
 90%|████████▉ | 1215/1352 [16:35<01:40,  1.37it/s][A
 90%|████████▉ | 1216/1352 [16:37<01:59,  1.14it/s][A
 90%|█████████ | 1217/1352 [16:37<01:40,  1.35it/s][A
 90%|█████████ | 1218/1352 [16:38<02:03,  1.09it/s][A
 90%|█████████ | 1219/1352 [16:39<01:43,  1.29it/s][A
 90%|█████████ | 1220/1352 [16:40<02:01,  1.08it/s][A
 90%|█████████ | 1221/1352 [16:40<01:41,  1.30it/s][A
 90%|█████████ | 1222/1352 [16:42<02:02,  1.06it/s][A
 90%|█████████ | 1223/1352 [16:42<01:39,  1.29it/s][A
 91%|█████████ | 1224/1352 [16:44<02:02,  1.05it/s][A
 91%|█████

Top One Error of the network on train images: 36 %



  0%|          | 1/400 [00:02<14:33,  2.19s/it][A
  0%|          | 2/400 [00:02<11:04,  1.67s/it][A
  1%|          | 3/400 [00:03<10:06,  1.53s/it][A
  1%|          | 4/400 [00:04<07:57,  1.21s/it][A
  1%|▏         | 5/400 [00:05<07:43,  1.17s/it][A
  2%|▏         | 6/400 [00:05<06:15,  1.05it/s][A
  2%|▏         | 7/400 [00:06<06:30,  1.01it/s][A
  2%|▏         | 8/400 [00:07<05:24,  1.21it/s][A
  2%|▏         | 9/400 [00:08<05:59,  1.09it/s][A
  2%|▎         | 10/400 [00:08<05:02,  1.29it/s][A
  3%|▎         | 11/400 [00:10<05:41,  1.14it/s][A
  3%|▎         | 12/400 [00:10<04:46,  1.35it/s][A
  3%|▎         | 13/400 [00:11<05:32,  1.16it/s][A
  4%|▎         | 14/400 [00:12<04:53,  1.32it/s][A
  4%|▍         | 15/400 [00:13<05:19,  1.20it/s][A
  4%|▍         | 16/400 [00:13<04:52,  1.31it/s][A
  4%|▍         | 17/400 [00:14<05:27,  1.17it/s][A
  4%|▍         | 18/400 [00:15<04:58,  1.28it/s][A
  5%|▍         | 19/400 [00:16<05:12,  1.22it/s][A
  5%|▌         | 20/