In [None]:
!pip install torch==1.5.1+cu101 torchvision==0.6.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html

In [None]:
import sys
import os
import torch
from torch import optim
from time import time
from tile2vec_model.datasets import TileTripletsDataset, GetBands, RandomFlipAndRotate
from tile2vec_model.datasets import ClipAndScale, ToFloatTensor, triplet_dataloader
from tile2vec_model.tilenet import make_tilenet
from tile2vec_model.training import prep_triplets, train_triplet_epoch

In [None]:
# Expose only GPU 0 to this process.
# NOTE(review): CUDA_VISIBLE_DEVICES is only honoured if it is set before CUDA is
# first initialised; torch is already imported in the cell above, so keep this cell
# ahead of any GPU work (ideally set it before `import torch`).
os.environ.update({'CUDA_VISIBLE_DEVICES': '0'})

# Boolean flag reused later to decide whether the model is moved with .cuda().
cuda = torch.cuda.is_available()
print('PyTorch is using GPU: {}'.format(cuda))

### Step 1: Set Up the Data Loader

In [None]:
# --- Data source (placeholders: edit before running) -------------------------
img_type = 'airbus'                     # change to correct image source
tile_dir = 'path_to_tiles_directory'    # TODO: directory holding the tiles
triplet_fp = 'path_to_triplets'         # TODO: path to the triplet data
bands = 5                               # also used below as TileNet's in_channels

# --- Loader behaviour ---------------------------------------------------------
batch_size, shuffle, num_workers = 16, True, 1
augment = False                         # forwarded as augment= to triplet_dataloader
n_triplets = 59999 * 2                  # forwarded as n_triplets= to triplet_dataloader

In [None]:
# Build the triplet data loader from the settings in the previous cell.
# Every option is forwarded as-is; pairs_only=False is passed explicitly
# (its meaning is defined by tile2vec_model.datasets.triplet_dataloader).
dataloader = triplet_dataloader(
    img_type,
    tile_dir,
    triplet_fp,
    bands=bands,
    augment=augment,
    batch_size=batch_size,
    shuffle=shuffle,
    num_workers=num_workers,
    n_triplets=n_triplets,
    pairs_only=False,
)
print('Dataloader set up complete.')

### Step 2: Set Up TileNet Model

In [None]:
# Model Params
# TileNet hyper-parameters.
# z_dim is passed to make_tilenet; presumably the embedding size (confirm in tilenet).
z_dim = 512
# The input channel count follows the band count configured for the data loader.
in_channels = bands

In [None]:
# Instantiate the network and switch it to training mode before optimisation.
TileNet = make_tilenet(in_channels=in_channels, z_dim=z_dim)
TileNet.train()

# Move the model's parameters to the GPU only when one was detected earlier.
if cuda:
    TileNet.cuda()

print('TileNet set up complete.')

In [None]:
# Set up the optimizer: Adam over every TileNet parameter, with beta1 lowered
# to 0.5 (PyTorch's default betas are (0.9, 0.999)).
lr = 0.001
optimizer = optim.Adam(params=TileNet.parameters(), lr=lr, betas=(0.5, 0.999))

### Step 3: Train the Model

In [None]:
# Training hyper-parameters consumed by train_triplet_epoch in the next cell.
epochs = 10                 # number of calls to train_triplet_epoch
margin, l2 = 10, 0.01       # forwarded as margin= / l2= to train_triplet_epoch
print_every = 10000         # forwarded as print_every= (presumably a batch logging interval)
save_models = False         # NOTE(review): not referenced in the cells visible here

In [None]:
t0 = time()
with open('tile2vec_model/training_output', 'w') as file:
    print('Begin training.................')
    for epoch in range(0, epochs):
        (avg_loss, avg_l_n, avg_l_d, avg_l_nd) = train_triplet_epoch(
            TileNet, cuda, dataloader, optimizer, epoch+1, margin=margin, l2=l2,
            print_every=print_every, t0=t0)