In [None]:
%pip install kagglehub
%pip install tqdm

import kagglehub


# Resolve the covidqu Kaggle dataset to a local directory. kagglehub fetches the
# latest version, or reuses an already-cached copy (as the Colab output shows),
# and returns the directory the files live in.
DATASET_HANDLE = "anasmohammedtahir/covidqu"
path = kagglehub.dataset_download(DATASET_HANDLE)
print("Path to dataset files:", path)

Using Colab cache for faster access to the 'covidqu' dataset.
Path to dataset files: /kaggle/input/covidqu


In [None]:
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from PIL import Image
import torchvision.transforms as T
import torch
import os


class SegmentationDataset(Dataset):
    """Chest X-ray images paired with infection and lung masks.

    Images and masks are matched by file name: for every file in
    ``image_dir`` a file with the same name is expected in both mask
    directories.

    Each item is ``(image, masks)``:
      * ``image`` - 3 x H x W float tensor in [0, 1]; the file is converted to
        RGB so it always has exactly three channels.
      * ``masks`` - 2 x H x W float tensor stacking the infection mask
        (channel 0) and the lung mask (channel 1).

    Args:
        image_dir: directory containing the input images.
        infection_mask_dir: directory containing the infection masks.
        lung_mask_dir: directory containing the lung masks.
        size: (height, width) every image and mask is resized to.
    """

    def __init__(self, image_dir, infection_mask_dir, lung_mask_dir, size=(256, 256)):
        super().__init__()
        self.image_dir = image_dir
        self.infection_mask_dir = infection_mask_dir
        self.lung_mask_dir = lung_mask_dir
        self.images = sorted(os.listdir(image_dir))
        self.image_tf = T.Compose([
            T.Resize(size),
            T.ToTensor(),
        ])
        # Nearest-neighbour resizing keeps mask values discrete instead of
        # blending foreground and background at region boundaries.
        self.mask_tf = T.Compose([
            T.Resize(size, interpolation=T.InterpolationMode.NEAREST),
            T.ToTensor(),
        ])

    def __len__(self):
        return len(self.images)

    def _load(self, directory, name, mode):
        """Open ``directory/name``, convert it to PIL ``mode`` and close the file."""
        with Image.open(os.path.join(directory, name)) as img:
            # convert() forces a full decode, so the returned image stays valid
            # after the file handle is closed on leaving the `with`.
            return img.convert(mode)

    def __getitem__(self, idx):
        name = self.images[idx]

        # convert("RGB") yields exactly 3 channels however the file is stored
        # (a grayscale X-ray is replicated across the channels), whereas
        # repeating the decoded tensor 3x would give 9/12 channels for RGB/RGBA.
        image = self.image_tf(self._load(self.image_dir, name, "RGB"))

        # convert("L") guarantees one channel per mask, so the stacked target
        # has exactly two channels to match a two-class model output.
        infection_mask = self.mask_tf(self._load(self.infection_mask_dir, name, "L"))
        lung_mask = self.mask_tf(self._load(self.lung_mask_dir, name, "L"))

        # ToTensor already returns C x H x W, so the masks concatenate along dim 0.
        return image, torch.cat((infection_mask, lung_mask), dim=0)

# `path` is the directory kagglehub returned in the download cell. On Colab it
# is /kaggle/input/covidqu, but elsewhere it is kagglehub's local cache, so all
# split paths are derived from it rather than hard-coded.
INFECTION_DATA_ROOT = os.path.join(
    path, "Infection Segmentation Data", "Infection Segmentation Data"
)


def make_covid_split(split):
    """Build the COVID-19 SegmentationDataset for one split directory (e.g. "Train", "Val")."""
    split_dir = os.path.join(INFECTION_DATA_ROOT, split, "COVID-19")
    return SegmentationDataset(
        os.path.join(split_dir, "images"),
        os.path.join(split_dir, "infection masks"),
        os.path.join(split_dir, "lung masks"),
    )


train_dat = make_covid_split("Train")
val_dat = make_covid_split("Val")

In [None]:
# Only the training split is shuffled; a fixed validation order keeps
# per-epoch validation numbers reproducible.
train_loader = DataLoader(
    train_dat,
    batch_size=16,
    shuffle=True,
    num_workers=2,
    # torchvision's DeepLabV3 ASPP pooling branch pools each sample to 1x1 and
    # then applies BatchNorm, which raises in train mode for a one-sample batch.
    # Dropping the ragged final batch avoids that crash for unlucky dataset sizes.
    drop_last=True,
)

val_loader = DataLoader(
    val_dat,
    batch_size=16,
    shuffle=False,
    num_workers=2,
)

In [None]:
import torchvision.models.segmentation as seg
from tqdm import tqdm

# Train on the GPU when one is available. `is_available` must be *called*:
# the bare function object is always truthy and would select "cuda" even on
# CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the global torch RNG so weight init and the training shuffle order are
# reproducible under Restart & Run All.
torch.manual_seed(0)

# DeepLabV3-ResNet50 from scratch with one output channel per target mask
# (channel 0 = infection, channel 1 = lung, as stacked by SegmentationDataset).
model = seg.deeplabv3_resnet50(
    weights=None,
    num_classes=2,
).to(device)

# Each output channel is scored as an independent binary mask
# (per-pixel sigmoid + binary cross-entropy).
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = AdamW(model.parameters(), lr=1e-4)

epochs = 100

train_losses = []  # loss of every training batch, across all epochs
val_losses = []    # mean validation loss, one entry per epoch

for epoch in range(1, epochs + 1):
    model.train()
    epoch_loss_sum, n_train_batches = 0.0, 0
    progress = tqdm(train_loader, desc=f"epoch {epoch}/{epochs}", leave=False)
    for image, masks in progress:
        image, masks = image.to(device), masks.to(device)

        out = model(image)["out"]
        loss = criterion(out, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        train_losses.append(loss_value)
        # A running sum keeps the average O(1) per batch instead of re-summing
        # the whole loss history at every step.
        epoch_loss_sum += loss_value
        n_train_batches += 1
        progress.set_postfix(
            loss=f"{loss_value:.4f}",
            avg=f"{epoch_loss_sum / n_train_batches:.4f}",
        )

    # Validation: eval mode (BatchNorm uses running statistics) and no autograd.
    model.eval()
    val_loss_sum, n_val_batches = 0.0, 0
    with torch.no_grad():
        for image, masks in val_loader:
            image, masks = image.to(device), masks.to(device)
            val_loss_sum += criterion(model(image)["out"], masks).item()
            n_val_batches += 1
    val_losses.append(val_loss_sum / max(n_val_batches, 1))

    print(
        f"epoch {epoch}: train avg {epoch_loss_sum / max(n_train_batches, 1):.4f}"
        f" - val avg {val_losses[-1]:.4f}"
    )



epoch :1 - batch id: 0 - loss: 0.6838056445121765 - avg: 0.6838056445121765
epoch :1 - batch id: 1 - loss: 0.6849275231361389 - avg: 0.6843665838241577
epoch :1 - batch id: 2 - loss: 0.680829644203186 - avg: 0.6831876039505005
epoch :1 - batch id: 3 - loss: 0.683901846408844 - avg: 0.6833661645650864
epoch :1 - batch id: 4 - loss: 0.6843580603599548 - avg: 0.6835645437240601
epoch :1 - batch id: 5 - loss: 0.6860566735267639 - avg: 0.6839798986911774
epoch :1 - batch id: 6 - loss: 0.6822748780250549 - avg: 0.6837363243103027
epoch :1 - batch id: 7 - loss: 0.6839419603347778 - avg: 0.6837620288133621
epoch :1 - batch id: 8 - loss: 0.6801130175590515 - avg: 0.6833565831184387
epoch :1 - batch id: 9 - loss: 0.6837793588638306 - avg: 0.683398860692978
epoch :1 - batch id: 10 - loss: 0.6841439604759216 - avg: 0.6834665970368818
epoch :1 - batch id: 11 - loss: 0.6822048425674438 - avg: 0.6833614508310953
epoch :1 - batch id: 12 - loss: 0.6811704635620117 - avg: 0.6831929133488581
epoch :1 - b