In [None]:
from google.colab import drive
drive.mount('/content/gdrive') 

In [3]:
import sys

sys.path.insert(0, r'/content/gdrive/My Drive/depth estimation')

In [4]:
import torch
import torch.nn as nn
from unet import ResNetUNet
from dataset import DepthMapDataset, DepthMapDataLoader
import matplotlib.pyplot as plt


In [5]:
dataset = DepthMapDataset(r'/content/gdrive/My Drive/depth-dataset/train')
dataloader = DepthMapDataLoader(dataset=dataset, batch_size=1)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [None]:
model = ResNetUNet()
model = model.to(device)
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.0001)
optimizer.zero_grad()

criterion = torch.nn.MSELoss()

In [None]:
from tqdm.notebook import tqdm
import gc
from google.colab import files

def debug_gpu():
    # Debug out of memory bugs.
    tensor_list = []
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
                tensor_list.append(obj)
        except:
            pass
    print(f'Count of tensors = {len(tensor_list)}.')


average_losses = []

for epoch in range(5): 
  average_loss = 0

  for i_batch, sample_batched in enumerate(tqdm(dataloader)):
    optimizer.zero_grad()

    images, depth_maps, shapes = sample_batched['images'], sample_batched['depth_maps'], sample_batched['shapes']
    max_size = max(shape[0] * shape[1] for shape in shapes)

    if max_size < 3 * 10**6:
      images = torch.from_numpy(images).float().to(device)
      output = model(images)
      depth_maps = torch.from_numpy(depth_maps).float().to(device) / 1000

      total_loss = 0

      for pred, real, shape in zip(output, depth_maps, shapes): 
          loss = criterion(pred[0, :shape[0], :shape[1]], real[:shape[0], :shape[1]])
          loss.backward()
          optimizer.step()
          loss.detach()
          total_loss += float(loss)

          del loss
          del pred
          del real

      del depth_maps
      del output
      del images

      average_loss += total_loss
      torch.cuda.empty_cache()

  average_losses.append(average_loss)
  torch.save(model.state_dict(), 'checkpoint.pth')
  files.download('checkpoint.pth')
  print('Average Loss:', average_loss)

HBox(children=(FloatProgress(value=0.0, max=2479.0), HTML(value='')))




<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Average Loss: 38367.878700107336


HBox(children=(FloatProgress(value=0.0, max=2479.0), HTML(value='')))




<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Average Loss: 35275.35473874211


HBox(children=(FloatProgress(value=0.0, max=2479.0), HTML(value='')))




<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Average Loss: 29416.07173538953


HBox(children=(FloatProgress(value=0.0, max=2479.0), HTML(value='')))




<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Average Loss: 30693.55955261737


HBox(children=(FloatProgress(value=0.0, max=2479.0), HTML(value='')))




<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Average Loss: 26986.12796534598
