In [1]:
# Automatically reload imported modules (e.g. the src.* helpers) before each cell runs,
# so edits to the library code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys
import os
import torch
from torch import optim
from time import time

In [3]:
# Make the parent directory importable so the `src.*` imports below resolve.
# NOTE(review): this is relative to the kernel's CWD -- assumes the notebook is launched
# from a subdirectory of the tile2vec repo root; confirm.
sys.path.append('../')

In [4]:
from src.datasets import TileTripletsDataset, GetBands, RandomFlipAndRotate, ClipAndScale, ToFloatTensor, triplet_dataloader
from src.tilenet import make_tilenet

In [5]:
from src.training import prep_triplets, train_triplet_epoch

# Step 1. Download triplets from bucket

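Download the triplets using the download link and unzip them into the directory /tile2vec/data/triplets. (Note: the dataloader below reads tiles from `tile_dir = '../data/tiles-pm25/'`; set `tile_dir` to wherever you unzipped the data.)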

# Step 2. Set up dataloader

In [6]:
# Expose only GPU 0 to PyTorch. CUDA_VISIBLE_DEVICES only takes effect if it is set
# before CUDA is initialized, so this must run before any CUDA call.
os.environ.update(CUDA_VISIBLE_DEVICES='0')
# Flag used by later cells to decide whether to move the model to the GPU.
cuda = torch.cuda.is_available()

Set up the dataloader for training.

In [7]:
# Change these arguments to match your directory layout and desired parameters.
img_type = 'naip'                  # imagery type passed to triplet_dataloader
tile_dir = '../data/tiles-pm25/'   # directory the dataloader reads tiles from
bands = 4                          # channels per tile; reused as TileNet's in_channels
augment = True                     # augmentation toggle forwarded to triplet_dataloader

# Batching / loading options.
batch_size, shuffle, num_workers = 50, True, 4

# NOTE(review): only 19 triplets are requested, fewer than a single batch of 50 --
# confirm this is intentional (e.g. a quick smoke test) rather than a typo.
n_triplets = 19

In [8]:
# Build the triplet DataLoader from the configuration cell above.
# NOTE(review): the semantics of pairs_only=True live in src.datasets -- confirm intended.
dataloader = triplet_dataloader(
    img_type,
    tile_dir,
    bands=bands,
    augment=augment,
    batch_size=batch_size,
    shuffle=shuffle,
    num_workers=num_workers,
    n_triplets=n_triplets,
    pairs_only=True,
)
print('Dataloader set up complete.')

Dataloader set up complete.


# Step 3. Set up TileNet

In [9]:
# The network's input channels must match the number of bands per tile;
# z_dim is forwarded to make_tilenet (presumably the embedding size -- see src.tilenet).
in_channels, z_dim = bands, 512

In [32]:
# Instantiate the network, put it in training mode, and move it to the GPU when one
# is available (per the `cuda` flag set in the environment cell).
TileNet = make_tilenet(in_channels=in_channels, z_dim=z_dim)
TileNet.train()
if cuda:
    TileNet.cuda()
print('TileNet set up complete.')

TileNet set up complete.


Set up the optimizer.

In [33]:
# Adam optimizer. beta1 is lowered from PyTorch's default (0.9) to 0.5;
# beta2 keeps its default of 0.999.
lr = 1e-3
adam_betas = (0.5, 0.999)
optimizer = optim.Adam(TileNet.parameters(), lr=lr, betas=adam_betas)

In [37]:
# Sanity check: shape of the first registered parameter (the first conv kernel in the
# repr below); its second dimension should equal in_channels.
first_param = next(iter(TileNet.parameters()))
first_param.shape

torch.Size([64, 4, 3, 3])

# Step 4. Train the model!

In [111]:
# Training hyperparameters, all forwarded to train_triplet_epoch / the save cell.
epochs = 5            # number of passes over the dataloader
margin = 10           # `margin` argument (presumably the triplet-loss margin -- see src.training)
l2 = 0.01             # `l2` argument (presumably a regularization weight -- confirm in src.training)
print_every = 1       # logging interval forwarded to train_triplet_epoch
save_models = False   # when True, the save cell writes the final weights into model_dir

Define the directory for saving models.

In [112]:
# Directory where checkpoints are written (relative to the notebook's CWD).
model_dir = '../models/pm25/'
# exist_ok=True makes this idempotent and avoids the check-then-create race of
# `if not os.path.exists(...): os.makedirs(...)`.
os.makedirs(model_dir, exist_ok=True)

In [113]:
# Train TileNet for `epochs` epochs and record each epoch's averages to `results_fn`
# as tab-separated rows: epoch, avg_loss, avg_l_n, avg_l_d, avg_l_nd.
t0 = time()
results_fn = 'results_file'

with open(results_fn, 'w') as file:
    file.write('epoch\tavg_loss\tavg_l_n\tavg_l_d\tavg_l_nd\n')

    print('Begin training.................')
    for epoch in range(0, epochs):
        (avg_loss, avg_l_n, avg_l_d, avg_l_nd) = train_triplet_epoch(
            TileNet, cuda, dataloader, optimizer, epoch+1, margin=margin, l2=l2,
            print_every=print_every, t0=t0)
        # The returned averages are presumably Python/NumPy scalars; str.format
        # handles whatever type train_triplet_epoch actually returns.
        file.write('{}\t{}\t{}\t{}\t{}\n'.format(
            epoch + 1, avg_loss, avg_l_n, avg_l_d, avg_l_nd))
        # Flush every epoch so partial results survive an interrupted run.
        file.flush()

Begin training.................
Finished epoch 1: 15.678s
  Average loss: 8.5045
  Average l_n: 3.4348
  Average l_d: -6.5597
  Average l_nd: -3.1249

Finished epoch 2: 31.461s
  Average loss: 7.2854
  Average l_n: 4.8844
  Average l_d: -13.6047
  Average l_nd: -8.7203

Finished epoch 3: 48.567s
  Average loss: 7.2188
  Average l_n: 5.3161
  Average l_d: -12.7226
  Average l_nd: -7.4065

Finished epoch 4: 83.274s
  Average loss: 3.7725
  Average l_n: 3.7728
  Average l_d: -17.0414
  Average l_nd: -13.2686

Finished epoch 5: 108.752s
  Average loss: 5.7249
  Average l_n: 2.8915
  Average l_d: -12.3243
  Average l_nd: -9.4329



In [114]:
# Save the model weights after the last epoch (only when save_models is enabled).
# The checkpoint name is derived from `epochs` so it reflects how long the model
# was actually trained, instead of a hard-coded epoch count.
if save_models:
    model_fn = os.path.join(model_dir, 'TileNet_epoch{}.ckpt'.format(epochs))
    torch.save(TileNet.state_dict(), model_fn)

# Test Model

In [116]:
# Display the network's module structure (its repr) as the cell output.
TileNet

TileNet(
  (conv1): Conv2d(4, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): ResidualBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): ResidualBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, 