# Theia

Gather intelligence from satellite data.

In [None]:
import json
import os

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
import torch

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision

## Custom code

Check out the [intelligenerator organization](https://github.com/intelligenerator) on GitHub for more information.

In [None]:
from unet_dataset import SatelliteImageDataset
from unet import UNet

import src.data as data_utils 

## Config

In [None]:
# Number of samples per batch handed to the DataLoader.
BATCH_SIZE = 1
# Number of full passes over the dataloader during training.
EPOCHS = 10
# Directory that receives the saved model weights and the loss history.
MODEL_PATH = './model/'

## Data Loading

In [None]:
# Build the dataset from the image and target directories; the project's
# `data_utils.transforms` is passed as the transform.
dataset = SatelliteImageDataset(
    images_dir='data/images',
    targets_dir='data/targets',
    transform=data_utils.transforms,
)

In [None]:
# Shuffled batches of BATCH_SIZE samples, loaded by two worker processes
# and assembled with the project's custom collate function.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    collate_fn=data_utils.collate,
)

### Data Visualization

In [None]:
def show_img(img, transpose=True):
    """Display a tensor as an image with matplotlib.

    Args:
        img: Tensor to display. It is detached from the autograd graph and
            copied to the CPU before conversion, so tensors that live on
            the GPU or track gradients (e.g. network outputs) work too.
        transpose: If True, reorder the axes with ``transpose(1, 2, 0)``,
            i.e. move the first axis to the end before plotting. Pass
            False for data that is already in a layout ``plt.imshow``
            accepts.
    """
    # The image is shown as-is; no un-normalisation is applied.
    # `Tensor.numpy()` raises for CUDA tensors and for tensors that require
    # grad, hence the explicit detach + cpu.
    npimg = img.detach().cpu().numpy()
    if transpose:
        npimg = npimg.transpose(1, 2, 0)
    plt.imshow(npimg)
    plt.show()

In [None]:
# Pull a single batch from the loader for a visual sanity check.
dataiter = iter(dataloader)
imgs, labels = next(dataiter)

# Inspect the first target in the batch: its shape and the distinct values
# it contains.
y = labels[0]
print(y.shape)
print(y.unique())

# Show the first input image, then its target. The target is squeezed and
# drawn without the axis transpose (presumably a single-channel mask --
# TODO confirm against the dataset).
show_img(imgs[0])
show_img(y.squeeze(), transpose=False)

## Training

In [None]:
# Train on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    dev = torch.device('cuda')
else:
    dev = torch.device('cpu')
print('Using device "%s" for training' % dev)

In [None]:
# Build the U-Net (3 input channels, 1 output channel), then move it to the
# selected training device.
net = UNet(in_channels=3, out_channels=1, padding=True)
net = net.to(dev)

In [None]:
# Loss: binary cross-entropy computed on the network's raw outputs.
criterion = nn.BCEWithLogitsLoss()
# Adam over all network parameters with a learning rate of 1e-4.
optimizer = optim.Adam(net.parameters(), lr=1e-4)
# Averaged loss values recorded during training, in order.
loss_history = []

In [None]:
# Train the network for EPOCHS passes over the dataloader and checkpoint the
# weights plus the recorded loss history after every epoch.

# Output locations. The directory is created up front: neither torch.save
# nor open() creates missing parent directories, and failing only after a
# full epoch of training would waste the run.
weights_file = os.path.join(MODEL_PATH, 'theia.pth')
history_file = os.path.join(MODEL_PATH, 'history.json')
os.makedirs(MODEL_PATH, exist_ok=True)

# Number of batches averaged into each logged / recorded loss value.
log_interval = 50

for epoch in range(EPOCHS):
    running_loss = 0.0

    for i, (imgs, labels) in enumerate(dataloader):
        imgs = imgs.to(dev)
        labels = labels.to(dev)

        # Gradients accumulate across backward() calls, so clear them first.
        optimizer.zero_grad()

        y = net(imgs)

        loss = criterion(y, labels)
        loss.backward()

        optimizer.step()

        running_loss += loss.item()

        # Every `log_interval` batches: report and record the mean loss of
        # that window, then start a new window. A trailing partial window
        # at the end of an epoch is not recorded.
        if i % log_interval == log_interval - 1:
            running_loss /= log_interval

            print('[%.3d / %.3d] Loss: %.9f' % (epoch, i, running_loss))
            loss_history.append(running_loss)

            running_loss = 0.0

    # Checkpoint after each epoch; both files are overwritten every time.
    torch.save(net.state_dict(), weights_file)
    with open(history_file, 'w') as f:
        json.dump(loss_history, f)