In [1]:
import gc
import torch

# Release memory left over from a previous run of this notebook in the same
# kernel. Python garbage is collected first so that tensors only reachable via
# reference cycles are actually freed before the CUDA caching allocator is asked
# to hand its cached blocks back to the driver.
gc.collect()
if torch.cuda.is_available():
    # torch.cuda.ipc_collect() lazily initialises CUDA, which raises on machines
    # without a usable GPU; the rest of the notebook supports a CPU fallback.
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from timm import create_model

class EnsembleNet(nn.Module):
    """Five-backbone image classifier with per-backbone heads and a fused head.

    The attribute names are historical and are kept unchanged so that existing
    ``state_dict`` checkpoints still load. The backbones they actually hold are:

    * ``vgg``       -- timm ``vit_large_patch32_224.orig_in21k`` built at ``img_size=512``
    * ``inception`` -- timm ``resnext101_32x8d`` with its last child (classifier) dropped
    * ``resnet``    -- timm ``efficientnet_b5`` with its last child (classifier) dropped
    * ``densenet``  -- timm ``mobilenetv3_large_100`` with its last child (classifier) dropped
    * ``swin``      -- timm ``swin_large_patch4_window7_224`` in ``features_only`` mode at ``img_size=512``

    Each backbone feeds its own linear head, and the concatenation of all five
    feature vectors feeds a small MLP (``fusion_net``).

    Args:
        num_classes: number of output classes produced by every head.
        pretrained: whether timm should load pretrained weights for the backbones.
    """

    def __init__(self, num_classes=10, pretrained=True):
        super().__init__()
        self.num_classes = num_classes

        # Pooled feature width produced by each backbone. Shared by the
        # per-backbone heads and the fusion MLP so the two cannot drift apart.
        vit_dim, resnext_dim, effnet_dim, mobilenet_dim, swin_dim = 1024, 2048, 2048, 1280, 1536

        # Registration order below matches the original so state_dict key order is unchanged.
        self.vgg = create_model("vit_large_patch32_224.orig_in21k", pretrained=pretrained, img_size=512)

        resnext = create_model('resnext101_32x8d', pretrained=pretrained)
        self.inception = nn.Sequential(*list(resnext.children())[:-1])

        effnet = create_model('efficientnet_b5', pretrained=pretrained)
        self.resnet = nn.Sequential(*list(effnet.children())[:-1])

        mobilenet = create_model('mobilenetv3_large_100', pretrained=pretrained)
        self.densenet = nn.Sequential(*list(mobilenet.children())[:-1])

        self.swin = create_model('swin_large_patch4_window7_224', pretrained=pretrained,
                                 features_only=True, img_size=512)

        self.vgg_head = nn.Linear(vit_dim, num_classes)
        self.google_head = nn.Linear(resnext_dim, num_classes)
        self.resnet_head = nn.Linear(effnet_dim, num_classes)
        self.dense_head = nn.Linear(mobilenet_dim, num_classes)
        self.swin_head = nn.Linear(swin_dim, num_classes)

        self.fusion_net = nn.Sequential(
            nn.Linear(vit_dim + resnext_dim + effnet_dim + mobilenet_dim + swin_dim, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, num_classes)
        )

    def flatten(self, out):
        """Global-average-pool a 4-D ``[N, C, H, W]`` map to ``[N, C]``.

        Not used by ``forward()``; kept for callers that may rely on it.
        """
        return F.adaptive_avg_pool2d(out, (1, 1)).squeeze(-1).squeeze(-1)

    def forward(self, x):
        """Run every backbone on ``x`` and return six logit tensors.

        Args:
            x: image batch. The ViT and Swin backbones are built with
               ``img_size=512``, so presumably ``[B, 3, 512, 512]`` -- TODO confirm.

        Returns:
            ``[fusion_out, vgg_out, google_out, resnet_out, dense_out, swin_out]``,
            each ``[B, num_classes]``.
        """
        # Pooled feature vectors; each width must match its head's input size.
        vgg_feat = self.vgg(x)              # [B, 1024]  (ViT-L/32, feeds vgg_head)
        inception_feat = self.inception(x)  # [B, 2048]  (ResNeXt-101, feeds google_head)
        resnet_feat = self.resnet(x)        # [B, 2048]  (EfficientNet-B5, feeds resnet_head)
        dense_feat = self.densenet(x)       # [B, 1280]  (MobileNetV3-L, feeds dense_head)
        # Last Swin stage. Averaging over dims 1 and 2 assumes a channels-last
        # [B, H, W, 1536] map -- consistent with swin_head's 1536 inputs; verify
        # if the timm version changes.
        swin_feat = self.swin(x)[3].mean((1, 2))

        fused_feat = torch.cat([vgg_feat, inception_feat, resnet_feat, dense_feat, swin_feat], dim=1)

        fusion_out = self.fusion_net(fused_feat)
        vgg_out = self.vgg_head(vgg_feat)
        google_out = self.google_head(inception_feat)
        resnet_out = self.resnet_head(resnet_feat)
        dense_out = self.dense_head(dense_feat)
        swin_out = self.swin_head(swin_feat)

        return [fusion_out, vgg_out, google_out, resnet_out, dense_out, swin_out]


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from PIL import Image

import pandas as pd
from torch.utils.data import Dataset

class MyDataset(Dataset):
    """Leaf images plus integer labels, split by a precomputed ``fold`` column.

    Expects under ``path``:

    * ``train.csv`` with columns ``image_id`` and ``label``;
    * ``validation_data.csv`` with columns ``image_id`` and ``fold``;
    * the image files in ``train_images/``.

    Rows with ``fold == 0`` form the validation split; rows with ``fold > 0``
    form the training split.

    Args:
        path: dataset root directory.
        transform: callable applied to each RGB ``PIL.Image``. When ``None``,
            the RGB image is returned untransformed.
        train: ``True`` for the training split, ``False`` for validation.
    """

    def __init__(self, path = "../input/cassava-leaf-disease-classification", transform = None, train = True):
        labels_df = pd.read_csv(path + "/train.csv")
        folds_df = pd.read_csv(path + "/validation_data.csv")
        data = pd.merge(labels_df, folds_df, on='image_id')
        in_split = data["fold"] > 0 if train else data["fold"] == 0
        data = data[in_split].reset_index(drop=True)
        self.images = [path + "/train_images/" + image_id for image_id in data["image_id"]]
        self.labels = data["label"].tolist()
        self.transform = transform

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        ``image`` is the transformed RGB image, or the RGB ``PIL.Image`` itself
        when no transform was given.
        """
        # The context manager guarantees the file handle is closed even for
        # formats Pillow does not auto-close after loading.
        with Image.open(self.images[idx]) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[idx]

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import os



cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if cuda else 'cpu')

# Seed torch's RNG so data shuffling, random augmentations and the freshly
# initialised heads are repeatable across runs.
SEED = 42
torch.manual_seed(SEED)

IMAGE_SIZE = (512, 512)
# ImageNet mean/std normalisation, shared by the train and validation pipelines.
NORMALIZE = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# Training pipeline: resize, then random flip / small rotation / colour jitter.
train_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.2),
    transforms.ToTensor(),
    NORMALIZE,
])

# Validation pipeline: deterministic resize + normalisation only.
valid_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    NORMALIZE,
])

train_dataset = MyDataset(transform=train_transform)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)

# Validation metrics do not depend on order, so a fixed order is used.
valid_dataset = MyDataset(transform=valid_transform, train=False)
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False)

model = EnsembleNet(num_classes=5).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
epoches = 500
current_epoch = 0

# To resume a previous trial, uncomment and point at its checkpoint.
# map_location lets a GPU-saved checkpoint load on the current device.
# checkpoint = torch.load('../output/trial_30/checkpoint.pth', map_location=device)
# model.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# current_epoch = checkpoint['epoch']
# best_acc = checkpoint['best_acc']

In [4]:
# Training loop.
# Each epoch trains all six heads jointly (sum of their cross-entropy losses),
# evaluates on the validation split, and writes checkpoints into a fresh
# ../output/trial_<n>/ directory created for this run.


def _new_trial_dir(root="../output"):
    """Create the first unused ``root/trial_<n>`` directory and return it with a trailing separator.

    Numbering starts at the number of entries already in ``root``, as before,
    but skips forward past existing names so a deleted trial cannot cause a
    ``FileExistsError``. ``root`` itself is created if missing.
    """
    os.makedirs(root, exist_ok=True)
    n = len(os.listdir(root))
    while True:
        candidate = os.path.join(root, f"trial_{n}")
        try:
            os.mkdir(candidate)
        except FileExistsError:
            n += 1
        else:
            return candidate + os.sep


def _run_epoch(model, loader, criterion, device, desc, optimizer=None):
    """Run one pass over ``loader``; return ``(mean loss per sample, accuracy in %)``.

    With an ``optimizer`` the model is trained; without one it is evaluated
    with gradients disabled. The loss is the sum of every head's criterion
    value (the training objective). The prediction is the argmax of the summed
    logits of all heads.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    training = optimizer is not None
    model.train(training)
    loss_sum, correct, total = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for images, labels in tqdm(loader, desc=desc):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            losses = sum(criterion(out, labels) for out in outputs)
            if training:
                optimizer.zero_grad()
                losses.backward()
                optimizer.step()
            batch_size = labels.size(0)
            # Weight by batch size so the epoch value is a per-sample mean
            # whatever the loader's batch size.
            loss_sum += losses.item() * batch_size
            # Detached, out-of-place sum: the metric neither extends the
            # autograd graph nor mutates the model's fusion logits.
            ensemble_logits = sum(out.detach() for out in outputs)
            correct += (ensemble_logits.argmax(dim=1) == labels).sum().item()
            total += batch_size
    if total == 0:
        raise ValueError(f"{desc}: loader produced no samples")
    return loss_sum / total, 100 * correct / total


# When resuming, the commented-out restore snippet in the setup cell assigns
# `best_acc`; a fresh run starts from 0.
best = globals().get("best_acc", 0)
path = _new_trial_dir()

# `epoches` is the last epoch number to run (inclusive).
for epoch in range(current_epoch + 1, epoches + 1):
    running_loss, acc = _run_epoch(model, train_loader, criterion, device,
                                   f"Epoch {epoch} training", optimizer)
    print(f"Epoch {epoch} Training, Loss: {running_loss:.3f}, Accuracy: {acc:.2f}%")

    running_loss, acc = _run_epoch(model, valid_loader, criterion, device,
                                   f"Epoch {epoch} validation")
    print(f"Epoch {epoch} Validation, Loss: {running_loss:.3f}, Accuracy: {acc:.2f}%")

    # Periodic snapshot, taken after the epoch so model<N>.pth holds the
    # weights produced by epoch N.
    if epoch % 5 == 0:
        torch.save(model.state_dict(), path + f'model{epoch}.pth')
    if acc > best:
        best = acc
        torch.save(model.state_dict(), path + 'best.pth')
    torch.save({'epoch': epoch, 'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(), 'best_acc': best},
               path + 'checkpoint.pth')
    


Epoch 1 training: 100%|██████████| 8503/8503 [50:38<00:00,  2.80it/s]


Epoch 1 Training, Loss: 4.746, Accuracy: 76.79%


Epoch 1 validation: 100%|██████████| 4392/4392 [04:04<00:00, 17.98it/s]


Epoch 1 Validation, Loss: 4.423, Accuracy: 79.78%


Epoch 2 training: 100%|██████████| 8503/8503 [47:20<00:00,  2.99it/s]


Epoch 2 Training, Loss: 3.638, Accuracy: 82.49%


Epoch 2 validation: 100%|██████████| 4392/4392 [04:00<00:00, 18.27it/s]


Epoch 2 Validation, Loss: 3.348, Accuracy: 85.66%


Epoch 3 training: 100%|██████████| 8503/8503 [47:15<00:00,  3.00it/s]


Epoch 3 Training, Loss: 3.270, Accuracy: 84.19%


Epoch 3 validation: 100%|██████████| 4392/4392 [04:00<00:00, 18.25it/s]


Epoch 3 Validation, Loss: 3.098, Accuracy: 86.70%


Epoch 4 training: 100%|██████████| 8503/8503 [47:23<00:00,  2.99it/s]


Epoch 4 Training, Loss: 3.066, Accuracy: 85.52%


Epoch 4 validation: 100%|██████████| 4392/4392 [03:58<00:00, 18.40it/s]


Epoch 4 Validation, Loss: 3.991, Accuracy: 83.72%


Epoch 5 training: 100%|██████████| 8503/8503 [47:23<00:00,  2.99it/s]


Epoch 5 Training, Loss: 2.910, Accuracy: 86.62%


Epoch 5 validation: 100%|██████████| 4392/4392 [04:06<00:00, 17.85it/s]


Epoch 5 Validation, Loss: 3.834, Accuracy: 83.77%


Epoch 6 training: 100%|██████████| 8503/8503 [47:20<00:00,  2.99it/s]


Epoch 6 Training, Loss: 2.786, Accuracy: 87.13%


Epoch 6 validation: 100%|██████████| 4392/4392 [04:00<00:00, 18.26it/s]


Epoch 6 Validation, Loss: 4.567, Accuracy: 84.84%


Epoch 7 training: 100%|██████████| 8503/8503 [47:15<00:00,  3.00it/s]


Epoch 7 Training, Loss: 2.705, Accuracy: 87.86%


Epoch 7 validation: 100%|██████████| 4392/4392 [04:04<00:00, 17.98it/s]


Epoch 7 Validation, Loss: 14.246, Accuracy: 78.73%


Epoch 8 training: 100%|██████████| 8503/8503 [47:23<00:00,  2.99it/s]


Epoch 8 Training, Loss: 2.575, Accuracy: 88.27%


Epoch 8 validation: 100%|██████████| 4392/4392 [04:01<00:00, 18.16it/s]


Epoch 8 Validation, Loss: 3.060, Accuracy: 87.93%


Epoch 9 training: 100%|██████████| 8503/8503 [47:13<00:00,  3.00it/s]


Epoch 9 Training, Loss: 2.510, Accuracy: 88.70%


Epoch 9 validation: 100%|██████████| 4392/4392 [04:01<00:00, 18.17it/s]


Epoch 9 Validation, Loss: 4.537, Accuracy: 81.03%


Epoch 10 training: 100%|██████████| 8503/8503 [47:20<00:00,  2.99it/s]


Epoch 10 Training, Loss: 2.437, Accuracy: 89.03%


Epoch 10 validation: 100%|██████████| 4392/4392 [04:01<00:00, 18.20it/s]


Epoch 10 Validation, Loss: 3.013, Accuracy: 88.68%


Epoch 11 training: 100%|██████████| 8503/8503 [47:15<00:00,  3.00it/s]


Epoch 11 Training, Loss: 2.349, Accuracy: 89.91%


Epoch 11 validation: 100%|██████████| 4392/4392 [03:59<00:00, 18.32it/s]


Epoch 11 Validation, Loss: 3.194, Accuracy: 87.96%


Epoch 12 training: 100%|██████████| 8503/8503 [47:20<00:00,  2.99it/s]


Epoch 12 Training, Loss: 2.274, Accuracy: 90.42%


Epoch 12 validation: 100%|██████████| 4392/4392 [04:01<00:00, 18.18it/s]


Epoch 12 Validation, Loss: 3.041, Accuracy: 88.07%


Epoch 13 training: 100%|██████████| 8503/8503 [47:23<00:00,  2.99it/s]


Epoch 13 Training, Loss: 2.173, Accuracy: 91.04%


Epoch 13 validation: 100%|██████████| 4392/4392 [04:02<00:00, 18.14it/s]


Epoch 13 Validation, Loss: 5.557, Accuracy: 86.20%


Epoch 14 training: 100%|██████████| 8503/8503 [47:15<00:00,  3.00it/s]


Epoch 14 Training, Loss: 2.120, Accuracy: 91.50%


Epoch 14 validation: 100%|██████████| 4392/4392 [04:02<00:00, 18.15it/s]


Epoch 14 Validation, Loss: 3.228, Accuracy: 87.91%


Epoch 15 training: 100%|██████████| 8503/8503 [47:17<00:00,  3.00it/s]


Epoch 15 Training, Loss: 2.054, Accuracy: 92.04%


Epoch 15 validation: 100%|██████████| 4392/4392 [04:04<00:00, 17.99it/s]


Epoch 15 Validation, Loss: 3.207, Accuracy: 88.66%


Epoch 16 training: 100%|██████████| 8503/8503 [47:16<00:00,  3.00it/s]


Epoch 16 Training, Loss: 1.962, Accuracy: 92.60%


Epoch 16 validation: 100%|██████████| 4392/4392 [04:01<00:00, 18.18it/s]


Epoch 16 Validation, Loss: 3.385, Accuracy: 88.71%


Epoch 17 training: 100%|██████████| 8503/8503 [47:07<00:00,  3.01it/s] 


Epoch 17 Training, Loss: 1.912, Accuracy: 92.83%


Epoch 17 validation: 100%|██████████| 4392/4392 [04:00<00:00, 18.29it/s]


Epoch 17 Validation, Loss: 3.584, Accuracy: 87.98%


Epoch 18 training: 100%|██████████| 8503/8503 [47:11<00:00,  3.00it/s]


Epoch 18 Training, Loss: 1.852, Accuracy: 93.60%


Epoch 18 validation: 100%|██████████| 4392/4392 [03:58<00:00, 18.39it/s]


Epoch 18 Validation, Loss: 3.632, Accuracy: 88.11%


Epoch 19 training: 100%|██████████| 8503/8503 [47:19<00:00,  2.99it/s]


Epoch 19 Training, Loss: 1.777, Accuracy: 94.14%


Epoch 19 validation: 100%|██████████| 4392/4392 [03:59<00:00, 18.32it/s]


Epoch 19 Validation, Loss: 4.069, Accuracy: 87.27%


Epoch 20 training: 100%|██████████| 8503/8503 [47:18<00:00,  3.00it/s]


Epoch 20 Training, Loss: 1.719, Accuracy: 94.67%


Epoch 20 validation: 100%|██████████| 4392/4392 [04:01<00:00, 18.19it/s]


Epoch 20 Validation, Loss: 3.862, Accuracy: 87.91%


Epoch 21 training: 100%|██████████| 8503/8503 [47:19<00:00,  2.99it/s]


Epoch 21 Training, Loss: 1.669, Accuracy: 95.04%


Epoch 21 validation: 100%|██████████| 4392/4392 [04:01<00:00, 18.16it/s]


Epoch 21 Validation, Loss: 4.110, Accuracy: 87.07%


Epoch 22 training: 100%|██████████| 8503/8503 [47:19<00:00,  2.99it/s]


Epoch 22 Training, Loss: 1.621, Accuracy: 95.28%


Epoch 22 validation: 100%|██████████| 4392/4392 [03:59<00:00, 18.36it/s]


Epoch 22 Validation, Loss: 5.289, Accuracy: 85.34%


Epoch 23 training: 100%|██████████| 8503/8503 [47:18<00:00,  3.00it/s]


Epoch 23 Training, Loss: 1.569, Accuracy: 95.71%


Epoch 23 validation: 100%|██████████| 4392/4392 [04:03<00:00, 18.04it/s]


Epoch 23 Validation, Loss: 4.485, Accuracy: 87.59%


Epoch 24 training: 100%|██████████| 8503/8503 [47:07<00:00,  3.01it/s]


Epoch 24 Training, Loss: 1.537, Accuracy: 95.87%


Epoch 24 validation: 100%|██████████| 4392/4392 [03:59<00:00, 18.34it/s]


Epoch 24 Validation, Loss: 4.419, Accuracy: 88.37%


Epoch 25 training: 100%|██████████| 8503/8503 [47:08<00:00,  3.01it/s]


Epoch 25 Training, Loss: 1.491, Accuracy: 96.29%


Epoch 25 validation: 100%|██████████| 4392/4392 [04:02<00:00, 18.14it/s]


Epoch 25 Validation, Loss: 5.083, Accuracy: 86.84%


Epoch 26 training:  55%|█████▌    | 4709/8503 [27:06<21:50,  2.90it/s]


KeyboardInterrupt: 