## Imports

In [65]:
import pandas as pd
import numpy as np
from pathlib import Path
from tqdm import tqdm

from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.functional as F

import torchvision.models as models
import torchvision.transforms as transforms

from torchmetrics import MeanSquaredError
from torchmetrics.functional import mean_squared_error

from torch.utils.data import Dataset, DataLoader

## Prepare Data

In [66]:
DATA_DIR = Path.cwd() / 'data'

#IMAGE_DIR = DATA_DIR / 'images'
IMAGE_DIR = DATA_DIR / 'sentinel-images'
IMAGE_DIR.mkdir(exist_ok=True, parents=True)

In [67]:
metadata = pd.read_csv(DATA_DIR / 'metadata.csv')
metadata.date = pd.to_datetime(metadata.date)
print(metadata.shape)
metadata.head()

(23570, 5)


Unnamed: 0,uid,latitude,longitude,date,split
0,aabm,39.080319,-86.430867,2018-05-14,train
1,aabn,36.5597,-121.51,2016-08-31,test
2,aacd,35.875083,-78.878434,2020-11-19,train
3,aaee,35.487,-79.062133,2016-08-24,train
4,aaff,38.049471,-99.827001,2019-07-23,train


In [68]:
torch.manual_seed(42)

torch.backends.cudnn.benchmark = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transform = transforms.Compose(
    [
        transforms.Resize((480, 480), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

In [79]:
class SatelliteDataset(Dataset):
    def __init__(self, images_list, transform=None):
        self.images_list = images_list
        self.transform = transform
    
    def __len__(self):
        return len(self.images_list)
    
    def __getitem__(self, index):
        image_pth = IMAGE_DIR / f"{self.images_list[index]}.png"
        image = Image.open(image_pth)
        
        if self.transform is not None:
            image = self.transform(image)
            
        return image
    
def get_loader(images_list, transform, batch_size=32, shuffle=True, num_workers=0, pin_memory=True):
    
    dataset = SatelliteDataset(images_list, transform)
    
    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    
    return loader, dataset

## Create Model

In [90]:
class CNN(nn.Module):
    def __init__(self, num_classes, dropout=0.2):
        super(CNN, self).__init__()
        self.efficientnet = models.efficientnet_v2_l(weights='DEFAULT')
        
        in_features = self.efficientnet.classifier[1].in_features
        
        self.efficientnet = nn.Sequential(*list(self.efficientnet.children())[:-1])
        self.fc = nn.Sequential(
            nn.Linear(in_features, in_features//2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(in_features//2, num_classes)
        )        
        
    def forward(self, x, target=None):
        features = self.efficientnet(x)
        features = features.view(features.size(0), -1)
        features = self.fc(features)
        
        loss = None
        
        if target is not None:
            preds = torch.argmax(F.softmax(features, dim=1))
            loss = mean_squared_error(preds, target, squared=False)
        
        return features, loss
    

## Train Model

In [95]:
num_classes = 5
learning_rate = 3e-4
dropout = 0.3

batch_size = 8
num_workers = 0

num_epochs = 5

In [77]:
train_list = metadata[metadata.split == 'train']['uid'].tolist()
test_list = metadata[metadata.split == 'test']['uid'].tolist()

In [80]:
train_loader, train_dataset = get_loader(train_list, transform, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_loader, test_dataset = get_loader(test_list, transform, batch_size=batch_size, shuffle=False, num_workers=num_workers)

In [91]:
model = CNN(num_classes).to(device)

criterion = MeanSquaredError(squared=False)

optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

In [94]:
model.train()

for epoch in range(num_epochs):
    pbar = tqdm(train_loader, total=len(train_loader))
    for idx, imgs in enumerate(pbar):
        imgs = imgs.to(device)
        
        targets = None
        
        outputs = model(imgs)
        
        loss = criterion(outputs, targets)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

  0%|          | 0/2133 [00:00<?, ?it/s]Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/keenansamway/miniconda3/envs/algae/lib/python3.9/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/Users/keenansamway/miniconda3/envs/algae/lib/python3.9/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'SatelliteDataset' on <module '__main__' (built-in)>
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/keenansamway/miniconda3/envs/algae/lib/python3.9/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/Users/keenansamway/miniconda3/envs/algae/lib/python3.9/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'SatelliteDataset' on <module '__main__' (built-in)>
  0%|   

RuntimeError: DataLoader worker (pid(s) 83762) exited unexpectedly

## Evaluate Model

## Create Predictions

In [None]:
train_labels = pd.read_csv(DATA_DIR / 'train_labels.csv')
print(train_labels.shape)
train_labels.head()

In [None]:
submission_format = pd.read_csv(DATA_DIR / 'submission_format.csv', index_col='uid')
print(submission_format.shape)
submission_format.head()