In [2]:
import torch
import torch.nn as nn
from torchvision import transforms
import os
import pandas as pd
import math
import random
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

In [5]:
BASE_DIR = "dataset/Air Pollution Image Dataset/Air Pollution Image Dataset/Combined_Dataset"
IMG_DIR = os.path.join(BASE_DIR, "All_img")
CSV_DIR = os.path.join(BASE_DIR, "IND_and_Nep_AQI_Dataset.csv")
NEGATIVE_DATASET_IMG_DIR = "/home/mukulboro/python_projects/deblur/dataset/flickr30k_images"
random.seed(42)

In [6]:
negative_images = os.listdir(NEGATIVE_DATASET_IMG_DIR)
random.shuffle(negative_images)

In [7]:
# Only need image name and AQI
df = pd.read_csv(CSV_DIR)
df = df[["Filename", "AQI"]]
df = df.rename(columns={"Filename": "filename", "AQI": "aqi"})
df.head()

Unnamed: 0,filename,aqi
0,BRI_Un_2023-02-02- 12.00-9.jpg,158
1,BRI_Un_2023-02-02- 12.00-8.jpg,158
2,BRI_Un_2023-02-02- 12.00-7.jpg,158
3,BRI_Un_2023-02-02- 12.00-6.jpg,158
4,BRI_Un_2023-02-02- 12.00-5.jpg,158


In [8]:
aqi = df["aqi"]
min(aqi), max(aqi)

(15, 450)

In [5]:
negative_images = negative_images[:len(df)]
negative_df = pd.DataFrame({
    "filename": negative_images,
    "aqi": -1
})

combined_df = pd.concat([df, negative_df], ignore_index=True)
combined_df = combined_df.sample(frac=1, random_state=42).reset_index(drop=True) # Randomly shuffle dataframe
combined_df

Unnamed: 0,filename,aqi
0,BENGR_Mod_2023-03-13-08.30-1-120.jpg,53
1,4543597747.jpg,-1
2,4549048831.jpg,-1
3,BIR_UNFSG_VF_2023-02-03-15.00-3-44.jpg,141
4,2066241589.jpg,-1
...,...,...
24475,6077121925.jpg,-1
24476,MH_UN_2023-03-15-15.00-1-281.jpg,162
24477,BIR_UNH_VF_2023-02-03- 10.00-2-38.jpg,163
24478,5004253043.jpg,-1


In [6]:
class AQIDataset(Dataset):
    def __init__(self, dataframe=combined_df, train=True, test=False, train_ratio=0.8):
        super().__init__()
        total_length = len(dataframe)
        if train == True and test == False:
            self.df = dataframe[:math.floor(train_ratio*total_length)]
        elif train == False and test == True:
            self.df = dataframe[math.floor(train_ratio*total_length):]
        else:
            raise ValueError("Invalid test and train combination")
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ])
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, index):
        current_item = self.df.iloc[index]
        if current_item["aqi"] == -1:
            img_dir = NEGATIVE_DATASET_IMG_DIR
        else:
            img_dir = IMG_DIR
        
        image = Image.open(os.path.join(img_dir, current_item["filename"]))
        image = self.transform(image)
        return {
            "aqi": current_item["aqi"],
            "image": image,
            "is_sky": 0 if current_item["aqi"] == -1 else 1
        }

In [7]:
class MultiDomainLoss(nn.Module):
    def __init__(self, alpha=0.5, beta=0.5):
        """Loss function for classification and selective regression
        
        Args:
            alpha (float): Weight for classification loss. Defaults to 0.5.
            beta (float): Weight for regression loss. Defaults to 0.5.
        """
        super().__init__()
        self.classification_loss = nn.CrossEntropyLoss()
        self.regression_loss = nn.MSELoss(reduction='none')  # Use 'none' to get per-sample losses
        self.alpha = alpha
        self.beta = beta
        
    def forward(self, pred, target):
        """Calculate loss with regression only for correct classifications
        
        Args:
            pred (tuple): (is_sky_logits, aqi_pred) 
            target (tuple): (is_sky_target, aqi_target)
        """
        is_sky_logits, aqi_pred = pred
        is_sky_target, aqi_target = target

        # Classification loss - applies to all samples
        classification_loss = self.classification_loss(is_sky_logits, is_sky_target)
        
        # Get predicted classes
        _, predicted_classes = torch.max(is_sky_logits, dim=1)
        
        # Create mask for correct classifications
        correct_mask = (predicted_classes == is_sky_target).float()
        
        # Per-sample regression losses
        per_sample_reg_loss = self.regression_loss(aqi_pred, aqi_target)
        
        # Only include regression loss for correctly classified samples
        masked_reg_loss = per_sample_reg_loss * correct_mask
        
        # Calculate average regression loss for correct samples
        correct_count = torch.sum(correct_mask)
        if correct_count > 0:
            regression_loss = torch.sum(masked_reg_loss) / correct_count
        else:
            regression_loss = torch.tensor(0.0, device=is_sky_logits.device)
        
        # Combined loss
        return self.alpha * classification_loss + self.beta * regression_loss

In [None]:
class MultiTaskCNN(nn.Module):
    def __init__(self):
        super().__init__()

        # Shared Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),  # [B, 3, H, W] → [B, 16, H, W]
            nn.BatchNorm2d(16),
            nn.GELU(),
            nn.MaxPool2d(2),  # ↓↓ resolution

            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),  # [B, 32, H/2, W/2]
            nn.BatchNorm2d(32),
            nn.GELU(),
            nn.MaxPool2d(2),  # ↓↓ resolution again
            
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),  # [B, 32, H/2, W/2]
            nn.BatchNorm2d(32),
            nn.GELU(),

            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),  # [B, 64, H/4, W/4]
            nn.BatchNorm2d(64),
            nn.GELU(),
            nn.AdaptiveAvgPool2d((1, 1))  # → [B, 64, 1, 1]
        )

        # Classification Head
        self.classifier = nn.Sequential(
            nn.Flatten(),       # [B, 64]
            nn.Linear(64, 32),
            nn.GELU(),
            nn.Dropout(0.5),
            nn.Linear(32, 2)    # binary classification logits
        )

        # Regression Head
        self.regressor = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 8),
            nn.ReLU(),
            nn.Linear(8, 1)    # output AQI as a float
        )

    def forward(self, x):
        features = self.encoder(x)

        is_sky_logits = self.classifier(features)
        aqi_pred = self.regressor(features).squeeze(1)  # [B]

        return is_sky_logits, aqi_pred

In [9]:
class Trainer:
    def __init__(self, model, loss_fn, optimizer, device="cuda" if torch.cuda.is_available() else "cpu"):
        self.model = model.to(device)
        self.model = torch.nn.DataParallel(self.model)
        self.loss_fn = loss_fn 
        self.optimizer = optimizer
        self.device = device

    def _run_epoch(self, dataloader, train=True):
        mode = 'train' if train else 'eval'
        if train:
            self.model.train()
        else:
            self.model.eval()

        running_loss = 0.0
        total = 0
        correct = 0
        total_mae = 0.0

        loop = tqdm(dataloader, desc=f"{mode.capitalize()}ing", leave=False)
        for batch in loop:
            images = batch["image"].to(self.device)
            is_sky = batch["is_sky"].to(self.device).long()
            aqi = batch["aqi"].to(self.device).float()

            self.optimizer.zero_grad()

            with torch.set_grad_enabled(train):
                is_sky_logits, aqi_pred = self.model(images)
                loss = self.loss_fn((is_sky_logits, aqi_pred), (is_sky, aqi))

                if train:
                    loss.backward()
                    self.optimizer.step()

            running_loss += loss.item() * images.size(0)

            # Accuracy for classification
            _, preds = torch.max(is_sky_logits, dim=1)
            correct += torch.sum(preds == is_sky).item()
            total += is_sky.size(0)

            # MAE for regression
            total_mae += torch.abs(aqi_pred - aqi).sum().item()

        avg_loss = running_loss / total
        acc = correct / total
        mae = total_mae / total

        return avg_loss, acc, mae

    def fit(self, train_loader, val_loader, epochs=10, checkpoint=None):
        
        best_mae = 99999
        current_epoch = 0
        val_loss_list = []
        train_loss_list = []
        acc_list = []
        mae_list = []
        
        if not checkpoint == None:
            data = torch.load(checkpoint)
            current_epoch = data["epoch"]
            best_mae = data["mae"]
            best_acc = data["acc"]
            val_loss_list = data["val_loss_list"]
            train_loss_list = data["train_loss_list"]
            acc_list = data["acc_list"]
            mae_list = data["mae_list"]
            self.model.load_state_dict(data["model"])
            print(f"Loaded model at epoch {current_epoch}, mae: {best_mae}, acc: {best_acc}")
            
        for epoch in range(current_epoch, epochs):
            print(f"Epoch {epoch}/{epochs}")
            train_loss, train_acc, train_mae = self._run_epoch(train_loader, train=True)
            val_loss, val_acc, val_mae = self._run_epoch(val_loader, train=False)
            
            train_loss_list.append(train_loss)
            val_loss_list.append(val_loss)
            acc_list.append(val_acc)
            mae_list.append(val_mae)
            
            if val_mae < best_mae:
                best_mae = val_mae
                torch.save({
                    "model": self.model.state_dict(),
                    "epoch": epoch,
                    "mae": val_mae,
                    "acc": val_acc,
                    "train_loss_list": train_loss_list,
                    "val_loss_list": val_loss_list,
                    "acc_list": acc_list,
                    "mae_list": mae_list
                }, "best_model.pth")
                print(f"Saved best model with acc:{val_acc} and mae:{val_mae} in epoch {epoch}")

            print(f"Train   | Loss: {train_loss:.4f} | Acc: {train_acc:.4f} | MAE: {train_mae:.2f}")
            print(f"Val     | Loss: {val_loss:.4f} | Acc: {val_acc:.4f} | MAE: {val_mae:.2f}")


In [10]:
model = MultiTaskCNN()
loss_fn = MultiDomainLoss(alpha=0.1, beta=0.9)
epochs = 100
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4,  weight_decay=1e-5)
trainer = Trainer(model=model, loss_fn=loss_fn, optimizer=optimizer)

train_dataset = AQIDataset(train=True, test=False)
train_loader = DataLoader(dataset=train_dataset, batch_size=16, num_workers=2)

val_dataset = AQIDataset(train=False, test=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=16, num_workers=2)

In [11]:
if os.path.exists("best_model.pth"):
    checkpoint = "best_model.pth"
else:
    checkpoint = None        
trainer.fit(train_loader=train_loader, val_loader=val_loader, epochs=epochs, checkpoint=checkpoint)

Loaded model at epoch 49, mae: 17.7043509288551, acc: 0.9846813725490197
Epoch 49/100


                                                             

Train   | Loss: 965.1922 | Acc: 0.9766 | MAE: 20.04
Val     | Loss: 754.1719 | Acc: 0.9841 | MAE: 17.89
Epoch 50/100


                                                             

Train   | Loss: 951.5973 | Acc: 0.9772 | MAE: 19.75
Val     | Loss: 783.4356 | Acc: 0.9843 | MAE: 17.82
Epoch 51/100


                                                             

Train   | Loss: 925.9298 | Acc: 0.9777 | MAE: 19.44
Val     | Loss: 775.5825 | Acc: 0.9847 | MAE: 17.98
Epoch 52/100


                                                             

Saved best model with acc:0.983251633986928 and mae:17.38232477116429 in epoch 52
Train   | Loss: 925.4862 | Acc: 0.9780 | MAE: 19.45
Val     | Loss: 753.3863 | Acc: 0.9833 | MAE: 17.38
Epoch 53/100


                                                             

Train   | Loss: 936.8265 | Acc: 0.9786 | MAE: 19.29
Val     | Loss: 855.2358 | Acc: 0.9826 | MAE: 18.72
Epoch 54/100


                                                             

Saved best model with acc:0.9836601307189542 and mae:17.334309157203226 in epoch 54
Train   | Loss: 903.0923 | Acc: 0.9778 | MAE: 19.02
Val     | Loss: 721.6870 | Acc: 0.9837 | MAE: 17.33
Epoch 55/100


                                                             

Saved best model with acc:0.9846813725490197 and mae:16.718325493382473 in epoch 55
Train   | Loss: 893.9582 | Acc: 0.9793 | MAE: 18.84
Val     | Loss: 688.3141 | Acc: 0.9847 | MAE: 16.72
Epoch 56/100


                                                             

Train   | Loss: 888.6476 | Acc: 0.9793 | MAE: 18.76
Val     | Loss: 760.1140 | Acc: 0.9847 | MAE: 17.41
Epoch 57/100


                                                             

Train   | Loss: 863.3057 | Acc: 0.9778 | MAE: 18.50
Val     | Loss: 727.4784 | Acc: 0.9830 | MAE: 17.03
Epoch 58/100


                                                             

Train   | Loss: 866.0251 | Acc: 0.9796 | MAE: 18.46
Val     | Loss: 712.5389 | Acc: 0.9849 | MAE: 16.76
Epoch 59/100


                                                             

Train   | Loss: 852.7839 | Acc: 0.9804 | MAE: 18.27
Val     | Loss: 711.7836 | Acc: 0.9845 | MAE: 17.29
Epoch 60/100


                                                             

Train   | Loss: 861.7785 | Acc: 0.9795 | MAE: 18.39
Val     | Loss: 760.6953 | Acc: 0.9849 | MAE: 17.30
Epoch 61/100


                                                             

Train   | Loss: 812.0270 | Acc: 0.9810 | MAE: 17.87
Val     | Loss: 711.5145 | Acc: 0.9839 | MAE: 16.84
Epoch 62/100


                                                             

Train   | Loss: 835.7975 | Acc: 0.9810 | MAE: 18.10
Val     | Loss: 723.0460 | Acc: 0.9835 | MAE: 16.83
Epoch 63/100


                                                             

Train   | Loss: 796.5856 | Acc: 0.9810 | MAE: 17.67
Val     | Loss: 750.8258 | Acc: 0.9861 | MAE: 17.37
Epoch 64/100


                                                             

Train   | Loss: 778.6777 | Acc: 0.9806 | MAE: 17.48
Val     | Loss: 740.1805 | Acc: 0.9849 | MAE: 17.44
Epoch 65/100


                                                             

Train   | Loss: 773.1075 | Acc: 0.9817 | MAE: 17.43
Val     | Loss: 723.9314 | Acc: 0.9851 | MAE: 16.73
Epoch 66/100


                                                             

Train   | Loss: 767.0782 | Acc: 0.9820 | MAE: 17.36
Val     | Loss: 727.2627 | Acc: 0.9867 | MAE: 17.03
Epoch 67/100


                                                             

Train   | Loss: 788.6728 | Acc: 0.9818 | MAE: 17.59
Val     | Loss: 745.5406 | Acc: 0.9879 | MAE: 16.99
Epoch 68/100


                                                             

Saved best model with acc:0.9842728758169934 and mae:16.276266797695286 in epoch 68
Train   | Loss: 753.9733 | Acc: 0.9821 | MAE: 17.11
Val     | Loss: 667.6569 | Acc: 0.9843 | MAE: 16.28
Epoch 69/100


                                                             

Train   | Loss: 746.0155 | Acc: 0.9831 | MAE: 16.96
Val     | Loss: 722.4171 | Acc: 0.9861 | MAE: 16.71
Epoch 70/100


                                                             

Saved best model with acc:0.9865196078431373 and mae:16.061715720525754 in epoch 70
Train   | Loss: 734.2698 | Acc: 0.9822 | MAE: 16.96
Val     | Loss: 658.1039 | Acc: 0.9865 | MAE: 16.06
Epoch 71/100


                                                             

Saved best model with acc:0.9852941176470589 and mae:15.997233385354086 in epoch 71
Train   | Loss: 709.2355 | Acc: 0.9828 | MAE: 16.64
Val     | Loss: 646.8613 | Acc: 0.9853 | MAE: 16.00
Epoch 72/100


                                                             

Train   | Loss: 717.5248 | Acc: 0.9834 | MAE: 16.80
Val     | Loss: 725.3113 | Acc: 0.9853 | MAE: 17.18
Epoch 73/100


                                                             

Train   | Loss: 711.0149 | Acc: 0.9833 | MAE: 16.57
Val     | Loss: 684.9112 | Acc: 0.9861 | MAE: 16.21
Epoch 74/100


                                                             

Train   | Loss: 695.5401 | Acc: 0.9836 | MAE: 16.49
Val     | Loss: 790.7066 | Acc: 0.9851 | MAE: 17.82
Epoch 75/100


                                                             

Saved best model with acc:0.9842728758169934 and mae:15.64784127590703 in epoch 75
Train   | Loss: 695.4520 | Acc: 0.9837 | MAE: 16.43
Val     | Loss: 610.7669 | Acc: 0.9843 | MAE: 15.65
Epoch 76/100


                                                             

Train   | Loss: 708.0578 | Acc: 0.9837 | MAE: 16.49
Val     | Loss: 694.8877 | Acc: 0.9863 | MAE: 16.74
Epoch 77/100


                                                             

Train   | Loss: 669.7992 | Acc: 0.9840 | MAE: 16.20
Val     | Loss: 718.8951 | Acc: 0.9843 | MAE: 17.13
Epoch 78/100


                                                             

Train   | Loss: 749.9567 | Acc: 0.9828 | MAE: 16.88
Val     | Loss: 641.2758 | Acc: 0.9875 | MAE: 15.77
Epoch 79/100


                                                             

Train   | Loss: 668.2332 | Acc: 0.9847 | MAE: 16.02
Val     | Loss: 852.4315 | Acc: 0.9861 | MAE: 18.57
Epoch 80/100


                                                             

Train   | Loss: 678.9616 | Acc: 0.9843 | MAE: 16.21
Val     | Loss: 813.4359 | Acc: 0.9873 | MAE: 18.04
Epoch 81/100


                                                             

Train   | Loss: 640.4352 | Acc: 0.9849 | MAE: 15.73
Val     | Loss: 701.0600 | Acc: 0.9882 | MAE: 16.15
Epoch 82/100


                                                             

Train   | Loss: 649.5579 | Acc: 0.9845 | MAE: 15.87
Val     | Loss: 658.1823 | Acc: 0.9877 | MAE: 15.97
Epoch 83/100


                                                             

Train   | Loss: 635.7568 | Acc: 0.9849 | MAE: 15.65
Val     | Loss: 649.9577 | Acc: 0.9865 | MAE: 16.36
Epoch 84/100


                                                             

Train   | Loss: 621.6573 | Acc: 0.9850 | MAE: 15.50
Val     | Loss: 617.6470 | Acc: 0.9871 | MAE: 15.65
Epoch 85/100


                                                             

Train   | Loss: 626.6644 | Acc: 0.9849 | MAE: 15.54
Val     | Loss: 801.5025 | Acc: 0.9869 | MAE: 18.19
Epoch 86/100


                                                             

Train   | Loss: 629.5695 | Acc: 0.9847 | MAE: 15.59
Val     | Loss: 608.0889 | Acc: 0.9879 | MAE: 15.67
Epoch 87/100


                                                             

Train   | Loss: 632.8386 | Acc: 0.9851 | MAE: 15.58
Val     | Loss: 715.6976 | Acc: 0.9873 | MAE: 17.23
Epoch 88/100


                                                             

Train   | Loss: 594.5407 | Acc: 0.9857 | MAE: 15.14
Val     | Loss: 641.2262 | Acc: 0.9884 | MAE: 16.17
Epoch 89/100


                                                             

Train   | Loss: 606.5410 | Acc: 0.9848 | MAE: 15.31
Val     | Loss: 843.4293 | Acc: 0.9833 | MAE: 19.29
Epoch 90/100


                                                             

Train   | Loss: 589.8449 | Acc: 0.9855 | MAE: 15.10
Val     | Loss: 857.9162 | Acc: 0.9861 | MAE: 19.10
Epoch 91/100


                                                             

Saved best model with acc:0.9861111111111112 and mae:14.945576082647236 in epoch 91
Train   | Loss: 579.1869 | Acc: 0.9855 | MAE: 14.96
Val     | Loss: 570.6252 | Acc: 0.9861 | MAE: 14.95
Epoch 92/100


                                                             

Train   | Loss: 593.0354 | Acc: 0.9853 | MAE: 15.05
Val     | Loss: 815.8022 | Acc: 0.9843 | MAE: 18.43
Epoch 93/100


                                                             

Train   | Loss: 589.5914 | Acc: 0.9864 | MAE: 15.05
Val     | Loss: 622.3008 | Acc: 0.9882 | MAE: 15.16
Epoch 94/100


                                                             

Train   | Loss: 562.0030 | Acc: 0.9868 | MAE: 14.74
Val     | Loss: 559.2287 | Acc: 0.9867 | MAE: 15.03
Epoch 95/100


                                                             

Train   | Loss: 560.6890 | Acc: 0.9865 | MAE: 14.68
Val     | Loss: 838.3492 | Acc: 0.9841 | MAE: 19.21
Epoch 96/100


                                                             

Train   | Loss: 575.0537 | Acc: 0.9858 | MAE: 14.89
Val     | Loss: 671.6555 | Acc: 0.9867 | MAE: 16.85
Epoch 97/100


                                                             

Train   | Loss: 565.7303 | Acc: 0.9866 | MAE: 14.69
Val     | Loss: 749.7885 | Acc: 0.9861 | MAE: 17.58
Epoch 98/100


                                                             

Train   | Loss: 533.7784 | Acc: 0.9866 | MAE: 14.38
Val     | Loss: 865.1231 | Acc: 0.9855 | MAE: 19.41
Epoch 99/100


                                                             

Train   | Loss: 538.0236 | Acc: 0.9868 | MAE: 14.40
Val     | Loss: 568.9176 | Acc: 0.9869 | MAE: 15.08


