In [1]:
import os
import torch
import torch_tensorrt
import torch.nn as nn
import torch.optim as optim
from dataset.video_dataset import VideoDataset
from models.conv_autoencoder import ConvAutoencoder
from models.conv_shifter import ConvShifter
from models.u_net import UNetAutoencoder
from models.FSRCNN import FSRCNN
from torchvision.transforms import ToTensor, Resize, Normalize
import matplotlib.pyplot as plt
import time
from torch.cuda.amp import autocast

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'USING DEVICE: {device}')



USING DEVICE: cuda


In [2]:
# --- Run configuration -------------------------------------------------------
START_SEC = 620        # offset into the sample video (seconds) where frame extraction starts
NAME = 'unet'          # model name; passed to UNetAutoencoder and used in the saved figure's filename
MODEL_ITER = 44_734    # training iteration of the checkpoint passed to model.load()

# The project root defaults to the original local checkout. It can be overridden with the
# IMAGE_SHIFTER_ROOT environment variable so the notebook runs on other machines
# without editing hardcoded home-directory paths.
PROJECT_ROOT = os.path.expanduser(os.environ.get('IMAGE_SHIFTER_ROOT', '~/Desktop/Image Shifter'))
MODEL_DIR = os.path.join(PROJECT_ROOT, 'vision_progress')
SAMPLE_VIDEO_PATH = os.path.join(PROJECT_ROOT, 'dataset', 'video1.mp4')

In [3]:
# Decode a short window of the sample video (batch_len_sec = 2/30 s) starting at START_SEC.
# No transform is passed, so frames are used exactly as VideoDataset returns them.
# (Resize/Normalize transforms were experimented with and are currently disabled.)
video_dataset = VideoDataset(
    video_path=SAMPLE_VIDEO_PATH,
    batch_len_sec=2 / 30,
    start_sec=START_SEC,
    device=device,
)

# Build the U-Net on the selected device and restore the requested checkpoint.
model = UNetAutoencoder(name=NAME, model_dir=MODEL_DIR)
model = model.to(device)
model.load(MODEL_ITER)

# Switch to inference mode before compiling. model.quantize() is currently disabled.
model.eval()

# Wrap the model with torch.compile's TensorRT backend; the engine is built lazily
# on the first forward call.
model = torch.compile(model, backend='tensorrt')

Extracted 3 frames from video /home/lsw/Desktop/Image Shifter/dataset/video1.mp4 between 620s and 620.0666666666667s


  self.load_state_dict(torch.load(model_path))



Loaded model state dict from iteration 44734 from /home/lsw/Desktop/Image Shifter/vision_progress/unet_model_44734.pth


In [4]:
# Take the first two decoded frames and move channels first:
# [N, H, W, C] -> [N, C, H, W] (channels-last input layout assumed; TODO confirm in VideoDataset).
samples = video_dataset.frames[:2].permute(0, 3, 1, 2)
assert len(samples) > 0, 'VideoDataset returned no frames to run inference on'

# Run each frame through the model on its own. With backend='tensorrt' the first call
# triggers engine compilation, so iteration 0 includes build time. Later iterations
# reflect steady-state latency.
for i, frame in enumerate(samples):
    batch = frame.unsqueeze(0)  # add a batch dimension -> [1, C, H, W]
    with torch.no_grad():
        if device.type == 'cuda':
            # Drain any queued GPU work so it is not billed to this forward pass.
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        reconstructed_frames = model(batch)
        if device.type == 'cuda':
            # CUDA kernels launch asynchronously; wait for them before stopping the clock,
            # otherwise only the launch overhead is measured.
            torch.cuda.synchronize()
        elapsed_time = (time.perf_counter() - start_time) * 1000  # milliseconds

    print(f'{i} MODEL EXECUTION TIME: {elapsed_time}ms')

# Only the last processed frame is plotted: drop the batch dimension and go back to [H, W, C].
sample_frame = batch[0].permute(1, 2, 0).cpu()
reconstructed_frame = reconstructed_frames[0].permute(1, 2, 0).cpu()

# Side-by-side comparison of the original and the reconstructed frame.
fig, (ax_orig, ax_recon) = plt.subplots(1, 2, figsize=(50, 14))

ax_orig.set_title("Original Frame")
ax_orig.imshow(sample_frame)
ax_orig.axis('off')

ax_recon.set_title("Reconstructed Frame")
ax_recon.imshow(reconstructed_frame)
ax_recon.axis('off')

# savefig does not create missing directories, so make sure ./results exists first.
os.makedirs('./results', exist_ok=True)
fig.savefig(f'./results/{NAME}_{MODEL_ITER}.png', bbox_inches='tight', pad_inches=0.1)

plt.show()

INFO:torch_tensorrt.dynamo.utils:Using Default Torch-TRT Runtime (as requested by user)
INFO:torch_tensorrt.dynamo.utils:Device not specified, using Torch default current device - cuda:0. If this is incorrect, please specify an input device, via the device keyword.
INFO:torch_tensorrt.dynamo.utils:Compilation Settings: CompilationSettings(enabled_precisions={<dtype.f32: 7>}, debug=False, workspace_size=0, min_block_size=5, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, refit=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=53

In [None]:
# import matplotlib.pyplot as plt

# def visualize_kernels(model, layer_index):
#     # Extract the weights of the convolutional layer
#     layer = model.encoder[layer_index]
#     if isinstance(layer, nn.Conv2d):
#         kernels = layer.weight.data.cpu().numpy()
#     else:
#         raise ValueError(f"Layer at index {layer_index} is not a Conv2d layer")

#     # Get the number of kernels (filters) and channels
#     num_filters = kernels.shape[0]
#     num_channels = kernels.shape[1]

#     # Plotting the kernels
#     fig, axes = plt.subplots(num_filters, num_channels, figsize=(num_channels, num_filters))
#     fig.subplots_adjust(hspace=0.1, wspace=0.1)

#     for i in range(num_filters):
#         for j in range(num_channels):
#             ax = axes[i, j]
#             kernel = kernels[i, j, :, :]
#             ax.imshow(kernel, cmap='gray')
#             ax.axis('off')

#     plt.show()            

# # Example: Visualize the kernels from the first convolutional layer in the encoder
# layer_index = 2  # Index of the Conv2d layer in the encoder
# visualize_kernels(model, layer_index)