# Training Demo
This notebook shows how to train a late fusion model and a robust TransFuser model. Run it from the root of the project directory, with the environment described by `requirements.txt` installed and activated.

The entire code of the folders transfuser, leaderboard, latefusion, scenario_runner and tools is adapted from the transfuser Github repository: https://github.com/autonomousvision/transfuser.

This notebook is adapted from transfuser/train.py.

## Setup

In [None]:
import json
import os
from tqdm import tqdm
import random 

import numpy as np
import matplotlib.pyplot as plt
from IPython.display import clear_output, display
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
torch.backends.cudnn.benchmark = True

from transfuser.config import GlobalConfig
from transfuser.model import TransFuser
from transfuser.data import CARLA_Data

%matplotlib inline
%load_ext autoreload
%autoreload 2

## Constants

In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Optimisation settings used by the data loaders, the optimizer and the epoch loop.
BATCH_SIZE = 1
LEARNING_RATE = 1e-4
NUM_EPOCHS = 1

# Directory that receives checkpoints and the training log.
LOG_DIR = 'log'

To train the robust model, set `USE_ROBUST` to `True`. With `USE_ROBUST = False`, no sensor failures are simulated during training.

In [None]:
# Toggle for the robust variant of the model.
USE_ROBUST = True

# Failure modes sampled during training: 'none' leaves the inputs untouched,
# 'rgb' / 'lidar' are only available when the robust variant is trained.
SENSOR_FAILURES = ['none', 'rgb', 'lidar'] if USE_ROBUST else ['none']

## Config, Data and Model

In [None]:
# Project-wide configuration object; the `use_pseudolidar` flag is switched off.
config = GlobalConfig()
config.use_pseudolidar = False

# Training and validation datasets, read from the locations named in the config.
train_set = CARLA_Data(root=config.train_data, config=config)
val_set = CARLA_Data(root=config.val_data, config=config)

# Only the training loader shuffles its samples.
dataloader_train = DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)
dataloader_val = DataLoader(
    val_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

# Network and its AdamW optimizer.
model = TransFuser(config, DEVICE, USE_ROBUST)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)

# Total number of elements across all parameters that require gradients.
model_parameters = (p for p in model.parameters() if p.requires_grad)
params = sum([np.prod(p.size()) for p in model_parameters])

## Inspect data

In [None]:
# Pull a single batch from the training loader for inspection.
batch = next(iter(dataloader_train))

# Sensor inputs: the four camera views and the LiDAR data.
fronts_in, lefts_in, rights_in, rears_in, lidars_in = (
    batch[key] for key in ('fronts', 'lefts', 'rights', 'rears', 'lidars')
)

# Labels stored alongside the sensor inputs.
command, gt_velocity, gt_steer, gt_throttle, gt_brake = (
    batch[key] for key in ('command', 'velocity', 'steer', 'throttle', 'brake')
)

In [None]:
# Side-by-side view of the first front-camera frame and the first LiDAR frame.
fig = plt.figure(figsize=(20, 15))

# Left panel: front camera.
ax1 = fig.add_subplot(1, 2, 1)
ax1.axis('off')
ax1.title.set_text('RGB')
ax1.title.set_font_properties({'size': 22})

# Right panel: LiDAR.
ax2 = fig.add_subplot(1, 2, 2)
ax2.axis('off')
ax2.title.set_text('LiDAR')
ax2.title.set_font_properties({'size': 22})

# The first dimension is moved last for imshow (presumably channels-first input — confirm).
ax1.imshow(fronts_in[0].squeeze().permute((1, 2, 0)))
# The LiDAR tensor is collapsed along its first axis by taking the maximum.
ax2.imshow(np.max(lidars_in[0].squeeze().numpy(), axis=0))

## Labels

In [None]:
# Show the label values of the first sample in the inspected batch, one per line.
print('\n'.join([
    f'Command: {command.numpy()[0]}',
    f'Velocity: {gt_velocity.numpy()[0]}',
    f'Steering Angle: {gt_steer.numpy()[0]}',
    f'Throttle: {gt_throttle.numpy()[0]}',
    f'Brake: {gt_brake.numpy()[0]}',
]))

## Training

In [None]:
def save(train_loss, val_loss, bestval, cur_epoch, cur_iter):
    """Write checkpoints and the training log for the epoch that just finished.

    Uses the notebook-level `model`, `optimizer` and `LOG_DIR`. Three kinds of
    files are written into `LOG_DIR`: a per-epoch model checkpoint, the most
    recent model/optimizer state plus a JSON log, and - when the latest
    validation loss is not worse than `bestval` - the best model/optimizer state.

    Args:
        train_loss: List of per-epoch training losses (stored in the log).
        val_loss: List of per-epoch validation losses; the last entry belongs
            to the current epoch and is compared against `bestval`.
        bestval: Best (lowest) validation loss seen before this call.
        cur_epoch: Index of the current epoch; used in the checkpoint file name.
        cur_iter: Number of training iterations so far (stored in the log).

    Returns:
        The best validation loss after taking the current epoch into account.
    """
    # torch.save does not create missing directories.
    os.makedirs(LOG_DIR, exist_ok=True)

    save_best = val_loss[-1] <= bestval
    if save_best:
        bestval = val_loss[-1]

    # Create a dictionary of all data to save
    log_table = {
        'epoch': cur_epoch,
        'iter': cur_iter,
        'bestval': bestval,
        'train_loss': train_loss,
        'val_loss': val_loss,
    }

    # Save ckpt for every epoch
    torch.save(model.state_dict(), os.path.join(LOG_DIR, 'model_%d.pth' % cur_epoch))

    # Save the recent model/optimizer states
    torch.save(model.state_dict(), os.path.join(LOG_DIR, 'model.pth'))
    torch.save(optimizer.state_dict(), os.path.join(LOG_DIR, 'recent_optim.pth'))

    # Log other data corresponding to the recent model
    with open(os.path.join(LOG_DIR, 'recent.log'), 'w') as f:
        f.write(json.dumps(log_table))

    tqdm.write('====== Saved recent model ======>')

    if save_best:
        torch.save(model.state_dict(), os.path.join(LOG_DIR, 'best_model.pth'))
        torch.save(optimizer.state_dict(), os.path.join(LOG_DIR, 'best_optim.pth'))
        tqdm.write('====== Overwrote best model ======>')

    return bestval

In [None]:
def _frames_to_device(frames, blank=False):
    """Return the first `config.seq_len` frames of one sensor on DEVICE as float32.

    With `blank=True` every frame is replaced by zeros of the same shape,
    which is how a failed sensor is simulated.
    """
    selected = [frames[i] for i in range(config.seq_len)]
    if blank:
        selected = [torch.zeros_like(frame) for frame in selected]
    return [frame.to(DEVICE, dtype=torch.float32) for frame in selected]


def _camera_and_lidar_inputs(data, sensor_failure='none'):
    """Build the (images, lidars) lists passed to the model for one batch.

    Images are ordered front, left, right, rear; side and rear views are
    skipped according to `config.ignore_sides` / `config.ignore_rear`.
    `sensor_failure` is 'none', 'rgb' (front camera zeroed) or 'lidar'
    (LiDAR zeroed).
    """
    images = _frames_to_device(data['fronts'], blank=(sensor_failure == 'rgb'))
    if not config.ignore_sides:
        images += _frames_to_device(data['lefts'])
        images += _frames_to_device(data['rights'])
    if not config.ignore_rear:
        images += _frames_to_device(data['rears'])
    lidars = _frames_to_device(data['lidars'], blank=(sensor_failure == 'lidar'))
    return images, lidars


def _waypoint_targets(data):
    """Stack the ground-truth waypoints stored after the first `config.seq_len` entries."""
    future = [
        torch.stack(data['waypoints'][i], dim=1).to(DEVICE, dtype=torch.float32)
        for i in range(config.seq_len, len(data['waypoints']))
    ]
    return torch.stack(future, dim=1)


def _sensor_state_targets(batch_size, sensor_failure='none'):
    """One-hot targets with columns (front camera failed, LiDAR failed, both working)."""
    column = {'rgb': 0, 'lidar': 1, 'none': 2}[sensor_failure]
    targets = torch.zeros(batch_size, 3, dtype=torch.float32)
    targets[:, column] = 1.
    return targets.to(DEVICE)


train_loss = []
val_loss = []
cur_iter = 0
bestval = 1e10

for epoch in range(NUM_EPOCHS):
    ####---------------- TRAINING -----------------####
    model.train()
    # Accumulators cover the whole epoch, so they are reset once per epoch.
    loss_epoch = 0.
    num_batches = 0

    for data in tqdm(dataloader_train):
        # Gradients from the previous step must not leak into this one.
        optimizer.zero_grad()

        # One failure mode per batch keeps the corrupted inputs and the
        # classification target consistent across all frames of the sequence.
        sensor_failure = random.choice(SENSOR_FAILURES)
        images, lidars = _camera_and_lidar_inputs(data, sensor_failure)

        gt_velocity = data['velocity'].to(DEVICE, dtype=torch.float32)
        target_point = torch.stack(data['target_point'], dim=1).to(DEVICE, dtype=torch.float32)
        gt_waypoints = _waypoint_targets(data)

        if USE_ROBUST:
            pred_wp, logits = model(images, lidars, target_point, gt_velocity)
            gt_classification = _sensor_state_targets(data['fronts'][0].shape[0], sensor_failure)
            loss = F.l1_loss(pred_wp, gt_waypoints, reduction='none').mean()
            classification_loss = F.cross_entropy(logits, gt_classification)
            total_loss = loss + classification_loss
        else:
            pred_wp = model(images, lidars, target_point, gt_velocity)
            total_loss = F.l1_loss(pred_wp, gt_waypoints, reduction='none').mean()

        total_loss.backward()
        optimizer.step()

        loss_epoch += float(total_loss.item())
        num_batches += 1
        cur_iter += 1

    # max(..., 1) avoids a ZeroDivisionError when the loader yields no batches.
    loss_epoch = loss_epoch / max(num_batches, 1)
    train_loss.append(loss_epoch)

    ####---------------- VALIDATION -----------------####
    model.eval()

    with torch.no_grad():
        num_batches = 0
        wp_epoch = 0.
        classification_loss_epoch = 0.
        batch_num = 0

        # Validation loop: inputs are never corrupted, so the classification
        # target is always "both working".
        for batch_num, data in enumerate(tqdm(dataloader_val), 0):
            images, lidars = _camera_and_lidar_inputs(data)

            gt_velocity = data['velocity'].to(DEVICE, dtype=torch.float32)
            target_point = torch.stack(data['target_point'], dim=1).to(DEVICE, dtype=torch.float32)
            gt_waypoints = _waypoint_targets(data)

            if USE_ROBUST:
                pred_wp, logits = model(images, lidars, target_point, gt_velocity)
                gt_classification = _sensor_state_targets(data['fronts'][0].shape[0])
                loss = F.l1_loss(pred_wp, gt_waypoints, reduction='none').mean()
                classification_loss = F.cross_entropy(logits, gt_classification)
                total_loss = loss + classification_loss
                # In robust mode the tracked value is waypoint + classification loss.
                wp_epoch += float(total_loss.item())
                classification_loss_epoch += float(classification_loss.item())
            else:
                # NOTE(review): this is the only model call that passes `train=False`;
                # confirm the non-robust model's forward accepts that keyword.
                pred_wp = model(images, lidars, target_point, gt_velocity, train=False)
                wp_epoch += float(F.l1_loss(pred_wp, gt_waypoints, reduction='none').mean())

            num_batches += 1

        wp_loss = wp_epoch / float(max(num_batches, 1))
        # Kept as a notebook variable for inspection; it is not part of the log line.
        classification_loss_final = classification_loss_epoch / float(max(num_batches, 1))
        tqdm.write(f'Epoch {epoch:03d}, Batch {batch_num:03d}:' + f' Wp: {wp_loss:3.3f}')

        val_loss.append(wp_loss)

    save(train_loss, val_loss, bestval, epoch, cur_iter)
    # Carry the best validation loss forward so later epochs compare against it.
    bestval = min(bestval, val_loss[-1])