In [1]:
import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import datasets, models
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import classification_report
from tqdm import tqdm

In [None]:
DATA_DIR = './Aerial_Landscapes'
BATCH_SIZE = 32
NUM_CLASSES = 15
EPOCHS = 10
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
])
test_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

In [None]:
def get_model():
    model = models.resnet50(pretrained=True)
    model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
    return model.to(DEVICE)

In [None]:
def train_model(model, train_loader, test_loader):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    for epoch in range(EPOCHS):
        model.train()
        total_loss = 0
        for imgs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
            imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            loss = criterion(model(imgs), labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Loss: {total_loss / len(train_loader):.4f}")
        evaluate_model(model, test_loader)

In [None]:
def evaluate_model(model, loader):
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for imgs, labels in loader:
            imgs = imgs.to(DEVICE)
            outputs = model(imgs)
            _, preds = torch.max(outputs, 1)
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(preds.cpu().numpy())
    print(classification_report(y_true, y_pred, zero_division=0))

In [None]:
def main():
    dataset = datasets.ImageFolder(DATA_DIR, transform=train_transforms)
    train_len = int(0.8 * len(dataset))
    test_len = len(dataset) - train_len
    train_data, test_data = random_split(dataset, [train_len, test_len])

    train_data.dataset.transform = train_transforms
    test_data.dataset.transform = test_transforms

    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=BATCH_SIZE)

    model = get_model()
    train_model(model, train_loader, test_loader)

if __name__ == '__main__':
    main()


Epoch 1/10: 100%|████████████████████████████████████████████████████████████████████| 300/300 [24:11<00:00,  4.84s/it]


Loss: 0.3907
              precision    recall  f1-score   support

           0       0.96      0.98      0.97       167
           1       0.75      1.00      0.85       173
           2       0.96      0.96      0.96       135
           3       0.97      0.74      0.84       158
           4       0.91      0.96      0.93       158
           5       0.96      0.97      0.97       160
           6       0.99      0.99      0.99       143
           7       0.94      0.95      0.95       159
           8       0.99      0.92      0.95       156
           9       0.93      0.96      0.95       172
          10       0.97      0.98      0.97       147
          11       0.99      0.88      0.93       167
          12       0.98      0.78      0.87       161
          13       0.88      0.97      0.92       158
          14       0.95      0.97      0.96       186

    accuracy                           0.93      2400
   macro avg       0.94      0.93      0.93      2400
weighted avg 

Epoch 2/10: 100%|████████████████████████████████████████████████████████████████████| 300/300 [22:03<00:00,  4.41s/it]


Loss: 0.1161
              precision    recall  f1-score   support

           0       0.98      0.99      0.99       167
           1       0.90      0.99      0.95       173
           2       1.00      0.98      0.99       135
           3       0.97      0.96      0.97       158
           4       0.96      0.99      0.97       158
           5       0.98      0.93      0.96       160
           6       0.93      0.97      0.95       143
           7       0.99      0.92      0.95       159
           8       0.98      0.96      0.97       156
           9       0.98      0.97      0.98       172
          10       0.99      0.97      0.98       147
          11       0.96      0.98      0.97       167
          12       0.96      0.98      0.97       161
          13       0.99      0.99      0.99       158
          14       0.97      0.95      0.96       186

    accuracy                           0.97      2400
   macro avg       0.97      0.97      0.97      2400
weighted avg 

Epoch 3/10: 100%|████████████████████████████████████████████████████████████████████| 300/300 [22:24<00:00,  4.48s/it]


Loss: 0.0704
              precision    recall  f1-score   support

           0       0.98      0.96      0.97       167
           1       0.97      0.94      0.96       173
           2       0.88      0.99      0.93       135
           3       0.99      0.94      0.96       158
           4       0.96      0.96      0.96       158
           5       0.97      0.94      0.96       160
           6       0.94      0.98      0.96       143
           7       0.95      0.97      0.96       159
           8       0.92      0.98      0.95       156
           9       0.95      0.97      0.96       172
          10       0.97      0.97      0.97       147
          11       0.99      0.94      0.97       167
          12       0.95      0.98      0.96       161
          13       1.00      0.97      0.98       158
          14       0.97      0.91      0.94       186

    accuracy                           0.96      2400
   macro avg       0.96      0.96      0.96      2400
weighted avg 

Epoch 4/10: 100%|████████████████████████████████████████████████████████████████████| 300/300 [21:42<00:00,  4.34s/it]


Loss: 0.0684
              precision    recall  f1-score   support

           0       0.96      0.98      0.97       167
           1       0.96      0.97      0.97       173
           2       0.99      0.98      0.99       135
           3       0.89      0.99      0.93       158
           4       0.96      0.97      0.97       158
           5       0.99      0.94      0.97       160
           6       0.96      0.98      0.97       143
           7       0.97      0.98      0.97       159
           8       0.99      0.94      0.96       156
           9       0.95      0.97      0.96       172
          10       0.99      0.98      0.98       147
          11       0.99      0.96      0.98       167
          12       1.00      0.91      0.95       161
          13       0.99      0.99      0.99       158
          14       0.94      0.97      0.96       186

    accuracy                           0.97      2400
   macro avg       0.97      0.97      0.97      2400
weighted avg 

Epoch 5/10: 100%|████████████████████████████████████████████████████████████████████| 300/300 [23:35<00:00,  4.72s/it]


Loss: 0.0583
              precision    recall  f1-score   support

           0       0.99      0.94      0.97       167
           1       0.97      0.96      0.96       173
           2       0.97      0.98      0.97       135
           3       0.95      0.96      0.96       158
           4       0.95      0.97      0.96       158
           5       0.98      0.95      0.97       160
           6       0.99      0.96      0.97       143
           7       0.96      0.99      0.97       159
           8       0.97      0.96      0.96       156
           9       0.95      0.93      0.94       172
          10       0.99      0.99      0.99       147
          11       0.99      0.93      0.96       167
          12       0.96      0.98      0.97       161
          13       0.95      0.96      0.96       158
          14       0.87      0.97      0.92       186

    accuracy                           0.96      2400
   macro avg       0.96      0.96      0.96      2400
weighted avg 

Epoch 6/10: 100%|████████████████████████████████████████████████████████████████████| 300/300 [24:22<00:00,  4.88s/it]


Loss: 0.0550
              precision    recall  f1-score   support

           0       0.92      1.00      0.96       167
           1       0.95      0.87      0.91       173
           2       0.96      0.99      0.97       135
           3       0.91      0.97      0.94       158
           4       0.96      0.96      0.96       158
           5       0.97      0.91      0.94       160
           6       0.95      0.98      0.96       143
           7       0.79      0.98      0.88       159
           8       0.96      0.97      0.96       156
           9       0.94      0.98      0.96       172
          10       0.99      0.99      0.99       147
          11       0.99      0.96      0.98       167
          12       1.00      0.72      0.84       161
          13       0.98      0.99      0.99       158
          14       0.99      0.95      0.97       186

    accuracy                           0.95      2400
   macro avg       0.95      0.95      0.95      2400
weighted avg 

Epoch 7/10: 100%|████████████████████████████████████████████████████████████████████| 300/300 [23:19<00:00,  4.66s/it]


Loss: 0.0479
              precision    recall  f1-score   support

           0       0.99      0.98      0.98       167
           1       0.98      0.94      0.96       173
           2       0.97      0.99      0.98       135
           3       0.97      0.94      0.96       158
           4       0.97      0.96      0.97       158
           5       0.98      0.97      0.98       160
           6       0.98      0.99      0.98       143
           7       0.96      0.92      0.94       159
           8       0.97      0.96      0.96       156
           9       0.94      0.97      0.96       172
          10       0.99      0.97      0.98       147
          11       0.99      0.97      0.98       167
          12       0.87      0.96      0.91       161
          13       0.99      0.99      0.99       158
          14       0.95      0.96      0.95       186

    accuracy                           0.97      2400
   macro avg       0.97      0.97      0.97      2400
weighted avg 

Epoch 8/10: 100%|████████████████████████████████████████████████████████████████████| 300/300 [24:42<00:00,  4.94s/it]


Loss: 0.0269
              precision    recall  f1-score   support

           0       0.97      0.99      0.98       167
           1       0.93      0.98      0.96       173
           2       0.90      0.99      0.94       135
           3       0.96      0.94      0.95       158
           4       0.96      0.99      0.97       158
           5       0.98      0.97      0.98       160
           6       0.99      0.95      0.97       143
           7       1.00      0.96      0.98       159
           8       0.98      0.93      0.95       156
           9       0.97      0.97      0.97       172
          10       0.98      0.99      0.99       147
          11       0.99      0.95      0.97       167
          12       0.98      0.99      0.98       161
          13       0.99      0.99      0.99       158
          14       0.98      0.97      0.97       186

    accuracy                           0.97      2400
   macro avg       0.97      0.97      0.97      2400
weighted avg 

Epoch 9/10: 100%|████████████████████████████████████████████████████████████████████| 300/300 [23:03<00:00,  4.61s/it]


Loss: 0.0375
              precision    recall  f1-score   support

           0       0.97      0.99      0.98       167
           1       0.96      0.97      0.97       173
           2       0.93      1.00      0.96       135
           3       0.98      0.93      0.95       158
           4       0.94      0.98      0.96       158
           5       0.99      0.97      0.98       160
           6       0.98      0.98      0.98       143
           7       0.99      0.95      0.97       159
           8       0.96      0.97      0.97       156
           9       0.95      0.94      0.95       172
          10       0.99      0.99      0.99       147
          11       0.99      0.94      0.96       167
          12       0.96      0.99      0.98       161
          13       0.96      0.99      0.98       158
          14       0.97      0.92      0.94       186

    accuracy                           0.97      2400
   macro avg       0.97      0.97      0.97      2400
weighted avg 

Epoch 10/10: 100%|███████████████████████████████████████████████████████████████████| 300/300 [21:35<00:00,  4.32s/it]


Loss: 0.0343
              precision    recall  f1-score   support

           0       0.98      0.99      0.98       167
           1       0.98      0.97      0.97       173
           2       0.99      0.96      0.97       135
           3       0.97      0.95      0.96       158
           4       0.93      0.97      0.95       158
           5       0.96      0.99      0.98       160
           6       0.99      0.94      0.97       143
           7       0.97      0.97      0.97       159
           8       0.96      0.98      0.97       156
           9       0.96      0.97      0.96       172
          10       1.00      0.99      1.00       147
          11       0.98      0.96      0.97       167
          12       0.97      0.99      0.98       161
          13       0.99      0.99      0.99       158
          14       0.97      0.95      0.96       186

    accuracy                           0.97      2400
   macro avg       0.97      0.97      0.97      2400
weighted avg 