In [1]:
#!sudo pip install pytube

In [2]:
#from pytube import YouTube
def download(url):
    youtubeObject = YouTube(url)
    youtubeObject = youtubeObject.streams.get_lowest_resolution()
    try:
        youtubeObject.download()
    except:
        print("An error has occurred")
    print("Download is completed successfully")

#url = "https://www.youtube.com/watch?v=FJZ-BHBKyos" # car chase
#url = "https://www.youtube.com/watch?v=1gwglom4FeA" # 8 hr nature
#url = "https://www.youtube.com/watch?v=8W1qF7l2A1c" # 6 hr nature
#download(url)

In [3]:

!pip install scikit-video

Defaulting to user installation because normal site-packages is not writeable


In [4]:
import cv2
import numpy as np
import skvideo.io  
import matplotlib.pyplot as plt

In [5]:
'''
##NOTE: use vreader and then create a downsampled np array
## try 36x64
resized_vid_arr = []
videodata = skvideo.io.vread("car_chase.mp4")  
for v in videodata:
  roi = cv2.resize(v, (36, 64))
  roi = roi.astype("float") / 255.0
  # roi = img_to_array(roi)
  # roi = np.expand_dims(roi, axis=0)
  resized_vid_arr.append(roi) 
resized_vid_arr = np.array(resized_vid_arr)
'''
resized_vid_arr = np.load("resized_vid_arr.npy")

In [6]:
resized_vid_arr.sum()

39977099.83529407

In [7]:
resized_vid_arr.shape

(18547, 64, 36, 3)

In [8]:
from moviepy.video.io.VideoFileClip import AudioFileClip
audioclip = AudioFileClip('car_chase.mp4', fps=9000)
# audioclip = videoclip.audio
audio_array = audioclip.to_soundarray()

In [9]:
video_and_audio_arr = []
aud_per_frame = audio_array.shape[0]// resized_vid_arr.shape[0]
for i in range(len(resized_vid_arr)):
  aud_in = audio_array[i*aud_per_frame : (i+1)*aud_per_frame+1]
  vec_in = [resized_vid_arr[i], aud_in]
  video_and_audio_arr.append(vec_in)

In [10]:
fps = 30
def make_frame(t):
  t = int(t * fps)
  # print(t)
  return video_and_audio_arr[t][0]

from moviepy.video.io.VideoFileClip import VideoClip
myclip = VideoClip(make_frame, duration = 20)

In [11]:
myclip.write_videofile('test.mp4', fps = 30)

Moviepy - Building video test.mp4.
Moviepy - Writing video test.mp4



                                                    

Moviepy - Done !
Moviepy - video ready test.mp4




In [12]:
## concatonate the audio
audio = video_and_audio_arr[0][1]
for i in range(20*40):
  next_aud = video_and_audio_arr[i+1][1]
  audio = np.vstack((audio, next_aud))
print(audio.shape)

(241101, 2)


In [13]:
!pip install pydub

Defaulting to user installation because normal site-packages is not writeable


In [14]:
import pydub 

def write(f, sr, x, normalized=True):
    """numpy array to MP3"""
    channels = 2 if (x.ndim == 2 and x.shape[1] == 2) else 1
    if normalized:  # normalized array - each item should be a float in [-1, 1)
        y = np.int16(x * 2 ** 15)
    else:
        y = np.int16(x)
    song = pydub.AudioSegment(y.tobytes(), frame_rate=sr, sample_width=2, channels=channels)
    song.export(f, format="mp3", bitrate="320k")

# sr = (audio_array.shape[0]// resized_vid_arr.shape[0])*fps
#write('test_audio_2.mp3', sr, audio)

*training loop from cs189 hw*

In [15]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
import gc

In [16]:
class VideoAudioDataset(Dataset):
    def __init__(self, video_frames, audio_frames, num_frames):
        self.video_frames = torch.tensor(video_frames, dtype=torch.float32).permute(0,3,2,1) # Permute to (N, C, H, W)
        self.aud_per_frame = audio_frames.shape[0]// (video_frames.shape[0])
        clip_amount = audio_frames.shape[0] % self.aud_per_frame
        self.audio_frames = torch.tensor(audio_frames[:-clip_amount], dtype=torch.float32).reshape(-1, self.aud_per_frame,audio_frames.shape[1]).permute(0,2,1) # Permute to (N, C, A)
        print(self.audio_frames.shape)
        self.num_frames = num_frames

    def __len__(self):
        return len(self.video_frames)

    def __getitem__(self, idx):
        idx += 1
        if idx < self.num_frames:
          num_zeros_needed = self.num_frames - idx
          vid_zeros = torch.zeros(num_zeros_needed, *self.video_frames[0].shape)
          aud_zeros = torch.zeros(num_zeros_needed, *self.audio_frames[0].shape)
          vid = torch.vstack((vid_zeros, self.video_frames[0:idx])).transpose(0,1)
          aud = torch.vstack((aud_zeros, self.audio_frames[0:idx])).transpose(0,1).reshape(2,-1).transpose(0,1)
          #print("aud_shape:", aud.shape, "reshaped:", torch.vstack((aud_zeros, self.audio_frames[0:idx])).transpose(0,1).shape)
          return (vid, aud)
        #vid shape example torch.Size([32, 3, 10, 36, 64])
        # aud shape example torch.Size([32, 2, 14710])
        vid = self.video_frames[idx-self.num_frames:idx].transpose(0,1)
        aud = self.audio_frames[idx-self.num_frames: (idx)].transpose(0,1).reshape(2,-1).transpose(0,1)
        #print("aud_shape:", aud.shape)
        # print('idx = ', idx, ' aud.size = ', aud.shape)
        return (vid, aud)

https://github.com/antecessor/Wavenet << source!

In [17]:
class AudConvEmbedding(nn.Module):
  def __init__(self, in_channels, output_dim):
    super(AudConvEmbedding, self).__init__()
    self.conv_layers = nn.Sequential(
        nn.Conv1d(in_channels = 2, out_channels = 2, kernel_size = 5, stride = 2),
        nn.LayerNorm(7353),
        nn.Conv1d(in_channels = 2, out_channels = 2, kernel_size = 5, stride = 2),
        nn.LayerNorm(3675),
        nn.Conv1d(in_channels = 2, out_channels = 2, kernel_size = 5, stride = 2),
        nn.LayerNorm(1836),
        nn.Conv1d(in_channels = 2, out_channels = 2, kernel_size = 5, stride = 2),
        nn.LayerNorm(916),
        nn.Conv1d(in_channels = 2, out_channels = 2, kernel_size = 5, stride = 2),
        nn.LayerNorm(456),
        nn.Linear(456, output_dim)
    )
  
  def forward(self, x):
    return self.conv_layers(x)

In [18]:
class MultiheadAttention(nn.Module):

    def __init__(self, input_dim, embed_dim, num_heads, dropout):
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be 0 modulo number of heads."

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.dropout = nn.Dropout(dropout)

        # Stack all weight matrices 1...h together for efficiency
        # Note that in many implementations you see "bias=False" which is optional
        self.qkv_proj = nn.Linear(input_dim, 3*embed_dim)
        self.o_proj = nn.Linear(embed_dim, input_dim)

        self._reset_parameters()

    def _reset_parameters(self):
        # Original Transformer initialization, see PyTorch documentation
        nn.init.xavier_uniform_(self.qkv_proj.weight)
        self.qkv_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.o_proj.weight)
        self.o_proj.bias.data.fill_(0)

    def scaled_dot_product(self, q, k, v, mask=None):
        d_k = q.size()[-1]
        attn_logits = torch.matmul(q, k.transpose(-2, -1))
        attn_logits = attn_logits / np.sqrt(d_k)
        if mask is not None:
            attn_logits = attn_logits.masked_fill(mask == 0, -9e15)
        attention = F.softmax(attn_logits, dim=-1)
        attention = self.dropout(attention)
        values = torch.matmul(attention, v)
        return values, attention

    def forward(self, x, mask=None, return_attention=False):
        batch_size, seq_length, _ = x.size()
        qkv = self.qkv_proj(x)

        # Separate Q, K, V from linear output
        qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3*self.head_dim)
        qkv = qkv.permute(0, 2, 1, 3) # [Batch, Head, SeqLen, Dims]
        q, k, v = qkv.chunk(3, dim=-1)
        #print(q.shape, k.shape, v.shape)
        # Determine value outputs
        values, attention = self.scaled_dot_product(q, k, v, mask=mask)
        values = values.permute(0, 2, 1, 3) # [Batch, SeqLen, Head, Dims]
        #print(values.shape)
        values = values.reshape(batch_size, seq_length, self.embed_dim)
        o = self.o_proj(values)

        if return_attention:
            return o, attention
        else:
            return o

In [19]:
class AddPosEncoding(nn.Module):
  def __init__(self, d_model = 256, input_dropout = 0.1, timing_dropout = 0.1, max_len = 512):
    super(AddPosEncoding, self).__init__()
    self.d_model = d_model
    self.input_dropout = input_dropout
    self.timing_dropout = timing_dropout
    self.max_len = max_len

    self.timing_table = nn.Parameter(torch.FloatTensor(max_len, d_model))
    nn.init.normal_(self.timing_table)
    self.input_dropout = nn.Dropout(input_dropout)
    self.timing_dropout = nn.Dropout(self.timing_dropout)

  
  def forward(self,x):
    x = self.input_dropout(x)
    timing = self.timing_table[None, :x.shape[1], :]
    timing = self.timing_dropout(timing)
    return x + timing 



In [20]:
class GLUTanh(nn.Module):
    def __init__(self, input_size, output_size):
        super(GLUTanh, self).__init__()
        self.glu = nn.GLU(dim=-1)
        self.linear = nn.Linear(input_size // 2, output_size)

    def forward(self, x):
        x = self.glu(x)
        x = self.linear(x)
        x = torch.tanh(x)
        return x

In [21]:
class AudTransformer(nn.Module):
  def __init__(self, in_channels, out_channels, hidden_dims = 256, seq_len = 100, dim_ff = 1024, n_layers = 3, n_head = 8, d_qkv = 64, dropout = 0.1):
        super(AudTransformer, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_dims = hidden_dims
        self.dim_ff = dim_ff
        self.n_layers = n_layers
        self.n_head = n_head
        self.d_qkv = d_qkv
        self.dropout = dropout

        self.mha_list = nn.ModuleList()
        self.mha_norms = nn.ModuleList()
        self.pff_list = nn.ModuleList()
        self.pff_norms = nn.ModuleList()

        #self.output_norm = nn.LayerNorm(hidden_dims)
        #self.linear = nn.Linear(self.hidden_dims, 2)
        #self.tanh = nn.Tanh()

        self.add_timing = AddPosEncoding(hidden_dims, max_len = seq_len)

        #self.embedding = AudConvEmbedding(2, self.hidden_dims)
        self.embedding2 = nn.Linear(2, self.hidden_dims)
        #self.output_proj = nn.Conv1d(in_channels = seq_len, out_channels = 1, kernel_size = 1, stride = 1)

        self.output_proj = nn.Sequential(
            nn.LayerNorm(hidden_dims),
            nn.Conv1d(in_channels = seq_len, out_channels = 2, kernel_size = 1, stride = 1),
            nn.Linear(self.hidden_dims, 1024),
            nn.GLU(),
            nn.Linear(512, 256),
        )
        

        for _ in range(n_layers):
          self.mha_list.append(MultiheadAttention(self.hidden_dims, self.d_qkv, self.n_head, self.dropout))
          self.mha_norms.append(nn.LayerNorm(hidden_dims))
          self.pff_list.append(nn.Sequential(
                              nn.Linear(hidden_dims, dim_ff),
                              nn.ReLU(),
                              nn.Linear(dim_ff, hidden_dims),
                              nn.Dropout(self.dropout)
                              ))
          self.pff_norms.append(nn.LayerNorm(hidden_dims))

  def forward(self, x):
    #print("tf:", x.shape)
    #x = self.embedding(x).transpose(2,1)
    x = self.embedding2(x)
    x = x + self.add_timing(x)
    #print("tf emb:", x.shape)
    for i in range(self.n_layers):
      x_norm = self.mha_norms[i](x)
      #print('x_norm.shape', x_norm.shape)
      att_out = self.mha_list[i](x)
      #print(att_out.shape)
      x = x_norm + att_out
      x_norm = self.pff_norms[i](x)
      ff_out = self.pff_list[i](x_norm)
      x = x_norm + ff_out
      #print("tf layer:", i, x.shape)
    
    #x = self.output_norm(x)
    #print("tf:", x.shape)
    #x = self.linear(x)
    x = self.output_proj(x)
    #print("tf:", x.shape)
    #print("tf:", x.shape)
    return x


In [22]:
class VidToAudFusion(nn.Module):
    def __init__(self, audio_dim, video_dims, audio_channels, vid_channels, dropout_prob=0.1, num_frames=10):
        super(VidToAudFusion, self).__init__()
        self.dim_proj1 = nn.Linear(video_dims[0] * video_dims[1], 200)
        self.dim_proj2 = nn.Linear(200 * num_frames, (audio_dim) - 1)
        self.audio_dim = audio_dim
        self.video_dims = video_dims
        self.vid_channels = vid_channels

        if audio_channels == vid_channels:
            self.channel_projection = nn.Identity()
        else:
            self.channel_projection = nn.Sequential(
                nn.Conv2d(vid_channels, audio_channels, kernel_size=1, stride=1, bias=True),
            )
        self.channel_projection2 = nn.Sequential(
                nn.Conv1d(num_frames, 2, kernel_size=1, stride=1, bias=True),
            )
            
    def forward(self, x):
      #x = x.reshape(x.shape[0], x.shape[1], x.shape[2], -1)
      #print("vid to aud:", x.shape)
      audio_proj = torch.flatten(x, start_dim=3)
      #print("vid to aud:", x.shape)
      #print(audio_proj.shape)  
      audio_proj = self.dim_proj1(audio_proj)
      #print("vid to aud:", audio_proj.shape)
      #print(audio_proj.shape)
      audio_proj = self.channel_projection(audio_proj).squeeze(2)
      #print("vid to aud:", audio_proj.shape)
      #audio_proj = self.channel_projection2(audio_proj)
      audio_proj = audio_proj.reshape(audio_proj.shape[0], -1)
      #print("vid to aud:", audio_proj.shape)
      audio_proj = self.dim_proj2(audio_proj)
      #print("vid to aud:", audio_proj.shape)
      #print(audio_proj.shape)
      return audio_proj

In [23]:
class ConvResBlock3D(nn.Module):
    def __init__(self, in_channels, out_channels, dropout_prob=0.1, stride=1, kernel_size=3, padding=1):
        super(ConvResBlock3D, self).__init__()
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=True)
        self.norm1 = nn.BatchNorm3d(out_channels)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout_prob)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding, bias=True)
        self.norm2 = nn.BatchNorm3d(out_channels)
        self.stride = stride
        
        if in_channels == out_channels:
            self.skip_connection = nn.Identity()
        else:
            self.skip_connection = nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=stride, bias=True),
                nn.BatchNorm3d(out_channels)
            )
            
    def forward(self, x):
        residual = self.skip_connection(x)
        
        out = self.conv1(x)
        out = self.norm1(out)
        out = self.relu(out)
        out = self.dropout(out)
        
        out = self.conv2(out)
        out = self.norm2(out)
        out = self.dropout(out)
        
        out = self.relu(out)
        out = out + residual

        return out

class ConvResBlock2D(nn.Module):
    def __init__(self, in_channels, out_channels, dropout_prob=0.1, stride=1, kernel_size=3, padding=1):
        super(ConvResBlock2D, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=True)
        self.norm1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout_prob)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding, bias=True)
        self.norm2 = nn.BatchNorm2d(out_channels)
        self.stride = stride
        
        if in_channels == out_channels:
            self.skip_connection = nn.Identity()
        else:
            self.skip_connection = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=True),
                nn.BatchNorm2d(out_channels)
            )
            
    def forward(self, x):
        residual = self.skip_connection(x)
        
        out = self.conv1(x)
        out = self.norm1(out)
        out = self.relu(out)
        out = self.dropout(out)
        
        out = self.conv2(out)
        out = self.norm2(out)
        out = self.dropout(out)
        
        out = self.relu(out)
        out = out + residual

        return out

In [24]:
### combines video into audio input vector and then passes through wavenet

class AttAudVideoNet(nn.Module):
    def __init__(self,audio_input_shape, video_input_shape,in_channels=2,out_channels=2, seq_len=1470, num_frames=5): ## NEED TO DEBUG THESE hyperperams 4.29
        super().__init__()
        self.transformer=AudTransformer(in_channels, out_channels, hidden_dims = 256, seq_len = seq_len, dim_ff = 1024, n_layers = 5, n_head = 8, d_qkv = 64, dropout = 0.1) 
        #self.activation = GLUTanh(4,2)
        self.activation = nn.Softmax(dim=2)
        self.vid_convs = nn.ModuleList([
            ConvResBlock3D(3, 128),
            ConvResBlock3D(128, 256),
            ConvResBlock3D(256, 64),
            ConvResBlock3D(64, 8)
        ])
        self.audio_input_shape = audio_input_shape
        self.video_input_shape = video_input_shape
        vid_dims_list = self.get_video_dims(video_input_shape)
        self.vid_to_aud = VidToAudFusion(self.audio_input_shape[-1], vid_dims_list[2], 1, 8, num_frames=num_frames)
        
        #self.output_lin = nn.Linear(256*in_channels, out_channels)

    def get_video_dims(self, video_input_shape):
        shape_list = []
        test_data = torch.ones(video_input_shape)
        out = self.vid_convs[0](test_data)
        shape_list.append((out.shape[-2],out.shape[-1]))
        for layer in self.vid_convs[1:]:
            out = layer(out)
            shape_list.append((out.shape[-2],out.shape[-1]))
        return shape_list

    def forward(self,vid, aud):
        #print(vid.shape, aud.shape)
        for i in range(len(self.vid_convs)):
          vid = self.vid_convs[i](vid)
        #print(vid.shape)
        vid_to_aud = self.vid_to_aud(vid)
        #print(vid_to_aud.shape, aud.shape)
        #print(vid_to_aud.shape, aud.shape)
        aud = aud + vid_to_aud.unsqueeze(2)
        aud=self.transformer(aud)
        #print(aud.shape)
        #aud=aud.reshape(aud.shape[0], -1)
        #aud = self.output_lin(aud)
        #print('output shape = ', aud.shape)
        return self.activation(aud)



In [25]:
print(resized_vid_arr.shape, resized_vid_arr.size)
print(audio_array.size)

(18547, 64, 36, 3) 128196864
11139480


In [26]:
num_frames = 5
train_data = VideoAudioDataset(resized_vid_arr[:int(len(resized_vid_arr)*0.8)], audio_array[:int(len(audio_array)*0.8)], num_frames = num_frames)
valid_data = VideoAudioDataset(resized_vid_arr[int(len(resized_vid_arr)*0.8):], audio_array[int(len(audio_array)*0.8):], num_frames = num_frames)
train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=32, shuffle=True)

torch.Size([14852, 2, 300])
torch.Size([3713, 2, 300])


In [27]:
for x in train_loader:
  print(x[0].shape)
  print(x[1].shape)
  break
  
count = 0
# for x in valid_loader:
#   count += 1
#   if count == 10:
#     print(x[0].shape)
#     print(x[1].shape)
#     break

torch.Size([16, 3, 5, 36, 64])
torch.Size([16, 1500, 2])


In [28]:
step_size = 2/256
test = np.arange(-1,1,step_size) + step_size
print(test.shape)
print(test)

(256,)
[-0.9921875 -0.984375  -0.9765625 -0.96875   -0.9609375 -0.953125
 -0.9453125 -0.9375    -0.9296875 -0.921875  -0.9140625 -0.90625
 -0.8984375 -0.890625  -0.8828125 -0.875     -0.8671875 -0.859375
 -0.8515625 -0.84375   -0.8359375 -0.828125  -0.8203125 -0.8125
 -0.8046875 -0.796875  -0.7890625 -0.78125   -0.7734375 -0.765625
 -0.7578125 -0.75      -0.7421875 -0.734375  -0.7265625 -0.71875
 -0.7109375 -0.703125  -0.6953125 -0.6875    -0.6796875 -0.671875
 -0.6640625 -0.65625   -0.6484375 -0.640625  -0.6328125 -0.625
 -0.6171875 -0.609375  -0.6015625 -0.59375   -0.5859375 -0.578125
 -0.5703125 -0.5625    -0.5546875 -0.546875  -0.5390625 -0.53125
 -0.5234375 -0.515625  -0.5078125 -0.5       -0.4921875 -0.484375
 -0.4765625 -0.46875   -0.4609375 -0.453125  -0.4453125 -0.4375
 -0.4296875 -0.421875  -0.4140625 -0.40625   -0.3984375 -0.390625
 -0.3828125 -0.375     -0.3671875 -0.359375  -0.3515625 -0.34375
 -0.3359375 -0.328125  -0.3203125 -0.3125    -0.3046875 -0.296875
 -0.2890625 -0

In [29]:
rand = np.random.uniform(-1,1, size=10).reshape(-1,1)
print(rand, rand.shape)

[[ 0.61538247]
 [-0.55269147]
 [ 0.70479517]
 [ 0.96814845]
 [ 0.04396678]
 [ 0.42740758]
 [-0.53049681]
 [-0.52506866]
 [ 0.0415665 ]
 [-0.45119573]] (10, 1)


In [30]:
# In order to be assigned bucket [i], the value must be geq than i -1 and less than i
tttt = np.array([-1.0001, -1, -.999, -.99, -.98, .99, .999, 1, 1.00001])

In [31]:
np.digitize(tttt, test)

array([  0,   0,   0,   1,   2, 254, 255, 256, 256])

In [32]:
np.digitize(rand, test)

array([[206],
       [ 57],
       [218],
       [251],
       [133],
       [182],
       [ 60],
       [ 60],
       [133],
       [ 70]])

In [33]:
# Quantization test
test2 = np.arange(-1,1,1/256)
print(np.digitize(test2, test))

[  0   0   1   1   2   2   3   3   4   4   5   5   6   6   7   7   8   8
   9   9  10  10  11  11  12  12  13  13  14  14  15  15  16  16  17  17
  18  18  19  19  20  20  21  21  22  22  23  23  24  24  25  25  26  26
  27  27  28  28  29  29  30  30  31  31  32  32  33  33  34  34  35  35
  36  36  37  37  38  38  39  39  40  40  41  41  42  42  43  43  44  44
  45  45  46  46  47  47  48  48  49  49  50  50  51  51  52  52  53  53
  54  54  55  55  56  56  57  57  58  58  59  59  60  60  61  61  62  62
  63  63  64  64  65  65  66  66  67  67  68  68  69  69  70  70  71  71
  72  72  73  73  74  74  75  75  76  76  77  77  78  78  79  79  80  80
  81  81  82  82  83  83  84  84  85  85  86  86  87  87  88  88  89  89
  90  90  91  91  92  92  93  93  94  94  95  95  96  96  97  97  98  98
  99  99 100 100 101 101 102 102 103 103 104 104 105 105 106 106 107 107
 108 108 109 109 110 110 111 111 112 112 113 113 114 114 115 115 116 116
 117 117 118 118 119 119 120 120 121 121 122 122 12

In [None]:
from tqdm.notebook import tqdm, trange

# wavenet = WaveNet(in_channels=2,out_channels=2,kernel_size=2,stack_size=23,layer_size=6).cuda().train()
#audio_test_data = torch.ones((1,2,1471))
seq_len = 300 * num_frames
audio_test_data = torch.ones((1,2,seq_len))
vid_test_data = torch.ones((1,3,num_frames,36,64))

wavenet = AttAudVideoNet(audio_input_shape=audio_test_data.shape, video_input_shape=vid_test_data.shape,in_channels=2,out_channels=2, seq_len=seq_len-1, num_frames=num_frames).cuda().train()
#load_path = 'transformer_vid_model3_5.7.pt'
#wavenet.load_state_dict(torch.load(load_path))

lr = 4e-5
epochs= 50
globalStep=1000
aud_quantization_factor=256

quant_step_size = 2/aud_quantization_factor
quant_bins = np.arange(-1,1,quant_step_size) + quant_step_size

optimizer=torch.optim.AdamW(wavenet.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
  optimizer,
  lr,
  epochs=epochs,
  steps_per_epoch=len(train_loader),
  pct_start=0.03,  # Warm up for 3% of the total training time
  )
lossFunction = torch.nn.CrossEntropyLoss()

def calc_accuracy(Out,Y):
    max_vals, max_indices = torch.max(Out,1)
    train_acc = (max_indices == Y).sum().item()/max_indices.size()[0]
    return train_acc
  


for epoch in range(epochs):
    for step, (vid_frames,aud_frames) in tqdm(enumerate(train_loader),desc="Training"):
         #vid_frames = vid_frames.cuda()
         target = aud_frames[:,-1,].cpu().numpy()
         aud_frames = aud_frames[:,:-1,:].cuda()
         vid_frames = vid_frames.cuda()
         #print(target.shape)
         #print(aud_frames.shape)
         output = wavenet(vid_frames, aud_frames).squeeze()
         #print('ttt', output.shape, target.shape)
         #print(output[0].detach().cpu().numpy(), target[0].cpu().numpy())
         #print(output.shape)
         #print(output)
         #print(output.dtype, target.dtype)
         quant_target = np.digitize(target, quant_bins)
         quant_target = torch.tensor(quant_target).cuda()
         #print(quant_target.shape, output.shape)
         #print(quant_target)
         #print(output.reshape(-1, aud_quantization_factor).shape)
         #print(quant_target.reshape(-1).shape)
         loss = lossFunction(output.reshape(-1, aud_quantization_factor), quant_target.reshape(-1))
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
         scheduler.step()
         if step%globalStep==0:
            # scheduler.step()
            # print(output.detach().numpy())
            # print(y_train.numpy())
            with torch.no_grad():
                accuracy=0
                val_loss=0
                for stepTest, (vid_frames,aud_frames) in tqdm(enumerate(valid_loader),desc="Validation"):
                    vid_frames = vid_frames.cuda()
                    target = aud_frames[:,-1,].cpu().numpy()
                    aud_frames = aud_frames[:,:-1,:].cuda()
                    output = wavenet(vid_frames, aud_frames).squeeze()
                    quant_target = np.digitize(target, quant_bins)
                    quant_target = torch.tensor(quant_target).cuda()
                    if stepTest==0:
                        print(target[:3])
                        print(output[:3])
                        print(quant_target[:3])
                        print(output.argmax(dim=2)[:3])
                        #print(lossFunction(output[0],target[0]).item())
                    #accuracy+=calc_accuracy(output,target)*100
                    val_loss+= lossFunction(output.reshape(-1, aud_quantization_factor), quant_target.reshape(-1)).item()
                    if stepTest>200:
                        print(output)
                        break
            print(f"loss for step {step} : {val_loss/stepTest}")

         
    print(f"epoch {epoch}")

    save_path = 'transformer_vid_model3_5.10.pt'
    torch.save(wavenet.state_dict(), save_path)

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[ 0.28152466 -0.21252441]
 [-0.19137573 -0.24133301]
 [ 0.09222412 -0.04415894]]
tensor([[[0.0030, 0.0050, 0.0060,  ..., 0.0037, 0.0039, 0.0054],
         [0.0037, 0.0035, 0.0035,  ..., 0.0035, 0.0037, 0.0031]],

        [[0.0031, 0.0052, 0.0057,  ..., 0.0036, 0.0041, 0.0053],
         [0.0037, 0.0034, 0.0034,  ..., 0.0036, 0.0036, 0.0030]],

        [[0.0033, 0.0051, 0.0058,  ..., 0.0034, 0.0040, 0.0054],
         [0.0037, 0.0035, 0.0036,  ..., 0.0035, 0.0036, 0.0031]]],
       device='cuda:0')
tensor([[164, 100],
        [103,  97],
        [139, 122]], device='cuda:0')
tensor([[197,  47],
        [  2,  23],
        [ 52,  47]], device='cuda:0')
loss for step 0 : 5.593390369415284
epoch 0


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[-0.01123047 -0.02890015]
 [-0.0473938   0.00872803]
 [ 0.07843018  0.00585938]]
tensor([[[3.0696e-07, 3.9486e-08, 1.3495e-07,  ..., 2.5805e-07,
          1.6711e-07, 1.2001e-07],
         [7.7865e-07, 1.3181e-07, 4.8899e-07,  ..., 8.1642e-07,
          4.9567e-07, 4.0263e-07]],

        [[4.2931e-07, 4.6601e-08, 1.6808e-07,  ..., 3.1452e-07,
          3.0135e-07, 2.0298e-07],
         [2.1241e-06, 3.2970e-07, 1.1183e-06,  ..., 1.6424e-06,
          1.5886e-06, 1.0434e-06]],

        [[2.7968e-07, 2.8337e-08, 1.0839e-07,  ..., 1.8592e-07,
          1.6921e-07, 1.2872e-07],
         [1.3115e-06, 1.8564e-07, 7.4493e-07,  ..., 9.8978e-07,
          8.7907e-07, 6.3237e-07]]], device='cuda:0')
tensor([[126, 124],
        [121, 129],
        [138, 128]], device='cuda:0')
tensor([[127, 127],
        [127, 127],
        [127, 127]], device='cuda:0')
loss for step 0 : 5.585195412843124
epoch 1


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[-0.0189209   0.09683228]
 [-0.10626221 -0.20648193]
 [ 0.24554443  0.36428833]]
tensor([[[8.7275e-08, 1.2836e-08, 4.4038e-08,  ..., 9.9992e-08,
          1.3364e-07, 4.3337e-08],
         [2.4346e-09, 2.6963e-10, 1.0814e-09,  ..., 2.9435e-09,
          4.0816e-09, 1.1238e-09]],

        [[7.4520e-08, 1.0844e-08, 4.5206e-08,  ..., 1.0185e-07,
          1.3111e-07, 4.3665e-08],
         [2.7919e-09, 3.1908e-10, 1.5396e-09,  ..., 3.9264e-09,
          5.2600e-09, 1.5251e-09]],

        [[8.7083e-08, 1.2640e-08, 4.9594e-08,  ..., 1.1533e-07,
          1.6316e-07, 4.9404e-08],
         [5.5265e-10, 3.5605e-11, 2.8364e-10,  ..., 8.0539e-10,
          1.2055e-09, 2.8994e-10]]], device='cuda:0')
tensor([[125, 140],
        [114, 101],
        [159, 174]], device='cuda:0')
tensor([[129, 129],
        [129, 129],
        [129, 129]], device='cuda:0')
loss for step 0 : 5.586148577151091
epoch 2


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[ 0.1741333   0.0960083 ]
 [ 0.05996704 -0.16375732]
 [ 0.2135315   0.04318237]]
tensor([[[3.7463e-09, 2.5870e-10, 1.6059e-09,  ..., 4.3825e-09,
          6.9203e-09, 1.5915e-09],
         [8.3214e-11, 3.9634e-12, 3.0924e-11,  ..., 1.0252e-10,
          1.5653e-10, 3.1698e-11]],

        [[7.4392e-09, 4.4267e-10, 2.8625e-09,  ..., 6.5435e-09,
          1.0440e-08, 2.4291e-09],
         [1.0941e-10, 3.2503e-12, 3.3834e-11,  ..., 9.2670e-11,
          1.5591e-10, 3.0258e-11]],

        [[5.4328e-09, 3.8230e-10, 2.2304e-09,  ..., 5.6219e-09,
          8.5277e-09, 2.1620e-09],
         [8.0904e-11, 3.1897e-12, 2.6436e-11,  ..., 8.7576e-11,
          1.3959e-10, 2.7533e-11]]], device='cuda:0')
tensor([[150, 140],
        [135, 107],
        [155, 133]], device='cuda:0')
tensor([[129, 129],
        [129, 129],
        [129, 129]], device='cuda:0')
loss for step 0 : 5.586130341239597
epoch 3


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[ 0.22195435  0.36376953]
 [-0.24035645  0.13388062]
 [ 0.5795288   0.32125854]]
tensor([[[4.0407e-09, 2.4693e-10, 1.5725e-09,  ..., 3.8307e-09,
          5.9221e-09, 1.4144e-09],
         [5.3488e-11, 1.7873e-12, 1.6309e-11,  ..., 4.9854e-11,
          7.8207e-11, 1.6120e-11]],

        [[4.5462e-09, 2.8289e-10, 2.0664e-09,  ..., 4.6103e-09,
          6.7020e-09, 1.7417e-09],
         [4.8564e-11, 1.4355e-12, 1.7393e-11,  ..., 4.4707e-11,
          7.2615e-11, 1.5693e-11]],

        [[2.7142e-09, 1.4131e-10, 1.0608e-09,  ..., 2.7058e-09,
          4.0478e-09, 9.6176e-10],
         [1.1116e-10, 5.1384e-12, 3.6460e-11,  ..., 1.0729e-10,
          1.8012e-10, 3.7649e-11]]], device='cuda:0')
tensor([[156, 174],
        [ 97, 145],
        [202, 169]], device='cuda:0')
tensor([[129, 129],
        [129, 129],
        [129, 129]], device='cuda:0')
loss for step 0 : 5.586148465198019
epoch 4


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[0.3355713  0.22140503]
 [0.00915527 0.30325317]
 [0.02871704 0.28811646]]
tensor([[[1.9412e-09, 9.5312e-11, 7.5549e-10,  ..., 1.7858e-09,
          3.0409e-09, 6.9596e-10],
         [1.6051e-11, 3.0977e-13, 4.8954e-12,  ..., 1.4830e-11,
          2.2697e-11, 3.9818e-12]],

        [[1.9731e-09, 1.1418e-10, 8.5009e-10,  ..., 1.9267e-09,
          3.0943e-09, 7.9169e-10],
         [1.2761e-11, 3.0897e-13, 4.3403e-12,  ..., 1.2579e-11,
          2.1484e-11, 4.2964e-12]],

        [[3.0615e-09, 1.7940e-10, 1.3131e-09,  ..., 2.5727e-09,
          4.1845e-09, 1.1182e-09],
         [1.2727e-10, 5.9893e-12, 4.3444e-11,  ..., 1.0121e-10,
          1.6558e-10, 3.9880e-11]]], device='cuda:0')
tensor([[170, 156],
        [129, 166],
        [131, 164]], device='cuda:0')
tensor([[129, 129],
        [129, 129],
        [129, 129]], device='cuda:0')
loss for step 0 : 5.586139455048934
epoch 5


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[ 0.11004639  0.21054077]
 [ 0.11578369 -0.19717407]
 [-0.19326782 -0.17941284]]
tensor([[[1.1235e-09, 4.9339e-11, 4.6049e-10,  ..., 1.0525e-09,
          1.6159e-09, 3.7092e-10],
         [1.3888e-11, 3.2586e-13, 4.2521e-12,  ..., 1.2430e-11,
          2.0743e-11, 3.9534e-12]],

        [[1.4063e-09, 6.0409e-11, 5.3648e-10,  ..., 1.1744e-09,
          2.0012e-09, 4.3604e-10],
         [1.7708e-11, 4.3641e-13, 4.9330e-12,  ..., 1.4967e-11,
          2.6163e-11, 4.6438e-12]],

        [[1.5077e-09, 6.4697e-11, 5.3797e-10,  ..., 1.2658e-09,
          2.0097e-09, 4.7594e-10],
         [1.5527e-11, 3.1158e-13, 3.9614e-12,  ..., 1.3330e-11,
          2.1235e-11, 4.0770e-12]]], device='cuda:0')
tensor([[142, 154],
        [142, 102],
        [103, 105]], device='cuda:0')
tensor([[129, 129],
        [129, 129],
        [129, 129]], device='cuda:0')
loss for step 0 : 5.586139372120733
epoch 6


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[-0.05337524 -0.11514282]
 [-0.1789856   0.08416748]
 [-0.33010864 -0.29556274]]
tensor([[[2.3891e-09, 1.3121e-10, 8.2999e-10,  ..., 2.4062e-09,
          3.2647e-09, 7.3472e-10],
         [4.3776e-11, 1.1145e-12, 1.1458e-11,  ..., 3.9192e-11,
          5.4281e-11, 9.6577e-12]],

        [[9.8749e-10, 4.0824e-11, 3.5179e-10,  ..., 9.3560e-10,
          1.5502e-09, 3.5403e-10],
         [2.4946e-11, 7.4953e-13, 7.3714e-12,  ..., 2.3442e-11,
          3.7590e-11, 7.4215e-12]],

        [[5.3917e-10, 1.9389e-11, 2.0193e-10,  ..., 6.1361e-10,
          8.1550e-10, 1.7633e-10],
         [7.6331e-12, 1.4285e-13, 2.0157e-12,  ..., 8.6193e-12,
          1.0836e-11, 2.0614e-12]]], device='cuda:0')
tensor([[121, 113],
        [105, 138],
        [ 85,  90]], device='cuda:0')
tensor([[129, 129],
        [129, 129],
        [129, 129]], device='cuda:0')
loss for step 0 : 5.586148448612379
epoch 7


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[-0.02706909 -0.03820801]
 [ 0.31582642  0.44128418]
 [-0.18600464 -0.14804077]]
tensor([[[4.1385e-10, 9.5618e-12, 1.2249e-10,  ..., 2.9932e-10,
          4.8518e-10, 8.1592e-11],
         [4.0465e-12, 4.1100e-14, 9.3093e-13,  ..., 2.6907e-12,
          4.6771e-12, 5.8230e-13]],

        [[8.7997e-10, 2.5609e-11, 2.4138e-10,  ..., 5.5256e-10,
          8.3404e-10, 1.9035e-10],
         [4.2523e-12, 3.8066e-14, 7.2503e-13,  ..., 2.1785e-12,
          3.8774e-12, 5.5697e-13]],

        [[7.5116e-10, 2.3294e-11, 2.1135e-10,  ..., 4.7895e-10,
          8.2988e-10, 1.7089e-10],
         [3.2560e-12, 3.1268e-14, 6.2007e-13,  ..., 1.9392e-12,
          3.5528e-12, 4.7322e-13]]], device='cuda:0')
tensor([[124, 123],
        [168, 184],
        [104, 109]], device='cuda:0')
tensor([[129, 129],
        [129, 129],
        [129, 129]], device='cuda:0')
loss for step 0 : 5.586139268460481
epoch 8


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[ 0.08731079 -0.18167114]
 [-0.18060303 -0.19528198]
 [-0.04034424 -0.11978149]]
tensor([[[4.9175e-10, 1.1795e-11, 1.4210e-10,  ..., 3.3174e-10,
          5.1133e-10, 1.0876e-10],
         [3.7960e-12, 3.6384e-14, 8.8957e-13,  ..., 2.3664e-12,
          3.8216e-12, 6.3020e-13]],

        [[5.4440e-10, 1.3466e-11, 1.5850e-10,  ..., 3.5530e-10,
          5.5591e-10, 1.1793e-10],
         [5.8324e-12, 6.3224e-14, 1.3165e-12,  ..., 3.4767e-12,
          6.0491e-12, 9.7788e-13]],

        [[1.0799e-09, 4.2953e-11, 3.4055e-10,  ..., 7.2937e-10,
          1.1610e-09, 2.8098e-10],
         [7.5525e-12, 1.1017e-13, 1.8247e-12,  ..., 4.7113e-12,
          7.9686e-12, 1.3729e-12]]], device='cuda:0')
tensor([[139, 104],
        [104, 103],
        [122, 112]], device='cuda:0')
tensor([[129, 129],
        [129, 129],
        [129, 129]], device='cuda:0')
loss for step 0 : 5.586148365684178
epoch 9


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[ 0.04129028 -0.3564148 ]
 [-0.37277222 -0.37979126]
 [ 0.29840088  0.26623535]]
tensor([[[2.9099e-10, 5.8189e-12, 7.7133e-11,  ..., 1.7529e-10,
          3.0324e-10, 5.8031e-11],
         [1.3676e-12, 1.0332e-14, 2.6156e-13,  ..., 8.0127e-13,
          1.4999e-12, 1.9505e-13]],

        [[5.4033e-10, 1.4252e-11, 1.3689e-10,  ..., 3.1718e-10,
          6.0064e-10, 1.1483e-10],
         [2.5822e-12, 2.2071e-14, 4.5707e-13,  ..., 1.4156e-12,
          2.9210e-12, 3.8773e-13]],

        [[2.8880e-10, 7.2969e-12, 8.7937e-11,  ..., 2.3392e-10,
          3.4342e-10, 7.4261e-11],
         [1.5627e-12, 1.3474e-14, 3.6930e-13,  ..., 1.2732e-12,
          1.8832e-12, 2.8878e-13]]], device='cuda:0')
tensor([[133,  82],
        [ 80,  79],
        [166, 162]], device='cuda:0')
tensor([[129, 129],
        [129, 129],
        [129, 129]], device='cuda:0')
loss for step 0 : 5.586139467488165
epoch 10


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[-0.23110962 -0.24975586]
 [-0.03390503 -0.28915405]
 [ 0.07885742  0.10482788]]
tensor([[[9.6968e-10, 3.5044e-11, 2.7987e-10,  ..., 6.3144e-10,
          9.7737e-10, 2.4737e-10],
         [3.4461e-12, 4.0495e-14, 7.3380e-13,  ..., 2.1207e-12,
          3.6112e-12, 5.9283e-13]],

        [[4.2670e-10, 1.0733e-11, 1.1557e-10,  ..., 2.5403e-10,
          3.9424e-10, 8.3964e-11],
         [3.2180e-12, 2.7372e-14, 5.8189e-13,  ..., 1.5501e-12,
          2.6241e-12, 4.1253e-13]],

        [[2.3071e-10, 5.6479e-12, 6.5608e-11,  ..., 1.6420e-10,
          2.5899e-10, 4.8990e-11],
         [1.2245e-12, 1.1173e-14, 2.4760e-13,  ..., 8.8905e-13,
          1.3956e-12, 1.7879e-13]]], device='cuda:0')
tensor([[ 98,  96],
        [123,  90],
        [138, 141]], device='cuda:0')
tensor([[129, 129],
        [129, 129],
        [129, 129]], device='cuda:0')
loss for step 0 : 5.586130453192669
epoch 11


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[ 0.41290283  0.41540527]
 [-0.06396484 -0.00579834]
 [ 0.02078247 -0.0100708 ]]
tensor([[[1.1541e-10, 2.7218e-12, 3.1747e-11,  ..., 9.4504e-11,
          1.4066e-10, 2.7384e-11],
         [1.5393e-12, 1.7428e-14, 3.1962e-13,  ..., 1.2210e-12,
          1.7235e-12, 2.8941e-13]],

        [[5.0116e-10, 1.7252e-11, 1.3590e-10,  ..., 3.6344e-10,
          5.3057e-10, 1.2054e-10],
         [1.9827e-12, 2.1476e-14, 4.3157e-13,  ..., 1.4634e-12,
          2.3318e-12, 3.7192e-13]],

        [[1.8710e-10, 3.3534e-12, 5.0413e-11,  ..., 1.1557e-10,
          1.8432e-10, 3.2258e-11],
         [2.6186e-12, 2.1964e-14, 5.2347e-13,  ..., 1.4414e-12,
          2.3108e-12, 3.4243e-13]]], device='cuda:0')
tensor([[180, 181],
        [119, 127],
        [130, 126]], device='cuda:0')
tensor([[129, 129],
        [129, 129],
        [129, 129]], device='cuda:0')
loss for step 0 : 5.586148577151091
epoch 12


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[-0.0987854   0.21414185]
 [ 0.08963013  0.10476685]
 [-0.03094482 -0.04537964]]
tensor([[[9.1678e-10, 3.4807e-11, 2.9274e-10,  ..., 5.7021e-10,
          8.2603e-10, 2.1612e-10],
         [5.2328e-12, 6.1073e-14, 1.1531e-12,  ..., 2.8432e-12,
          4.2241e-12, 8.1207e-13]],

        [[3.4119e-10, 9.2398e-12, 1.0832e-10,  ..., 2.5536e-10,
          3.6368e-10, 7.8776e-11],
         [2.2386e-12, 1.9260e-14, 5.6725e-13,  ..., 1.5247e-12,
          2.2878e-12, 3.6979e-13]],

        [[2.0559e-10, 4.6924e-12, 6.2569e-11,  ..., 1.4259e-10,
          2.3681e-10, 4.6842e-11],
         [1.8904e-12, 1.7754e-14, 4.4719e-13,  ..., 1.1948e-12,
          2.0302e-12, 3.2117e-13]]], device='cuda:0')
tensor([[115, 155],
        [139, 141],
        [124, 122]], device='cuda:0')
tensor([[129, 129],
        [129, 129],
        [129, 129]], device='cuda:0')
loss for step 0 : 5.586148577151091
epoch 13


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[-0.21951294  0.6246338 ]
 [ 0.06604004  0.30316162]
 [-0.01754761  0.00741577]]
tensor([[[2.1042e-10, 3.5414e-12, 4.8464e-11,  ..., 1.2383e-10,
          1.6528e-10, 3.6483e-11],
         [1.4970e-12, 9.3877e-15, 2.2948e-13,  ..., 8.3815e-13,
          1.0946e-12, 1.7443e-13]],

        [[4.6484e-10, 1.3095e-11, 1.3637e-10,  ..., 3.0384e-10,
          4.2727e-10, 1.0811e-10],
         [2.1369e-12, 1.6687e-14, 4.2313e-13,  ..., 1.2732e-12,
          1.8474e-12, 3.3714e-13]],

        [[1.1952e-10, 2.1010e-12, 3.0999e-11,  ..., 8.3362e-11,
          1.3234e-10, 2.4886e-11],
         [1.7888e-12, 1.6949e-14, 3.8651e-13,  ..., 1.2771e-12,
          1.8040e-12, 3.0337e-13]]], device='cuda:0')
tensor([[ 99, 207],
        [136, 166],
        [125, 128]], device='cuda:0')
tensor([[129, 129],
        [129, 129],
        [129, 129]], device='cuda:0')
loss for step 0 : 5.586139513098675
epoch 14


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[ 0.16845703  0.07034302]
 [ 0.1399231   0.42370605]
 [-0.00747681  0.07336426]]
tensor([[[1.9930e-10, 5.3593e-12, 6.0103e-11,  ..., 1.6606e-10,
          2.2874e-10, 4.9855e-11],
         [2.7236e-12, 3.4139e-14, 6.7966e-13,  ..., 1.9514e-12,
          2.7466e-12, 5.2295e-13]],

        [[1.8804e-10, 4.3811e-12, 5.5491e-11,  ..., 1.4289e-10,
          1.9159e-10, 4.4389e-11],
         [1.8116e-12, 1.8337e-14, 4.1022e-13,  ..., 1.4032e-12,
          1.7023e-12, 3.2448e-13]],

        [[2.3396e-10, 5.4094e-12, 6.5128e-11,  ..., 1.5511e-10,
          2.0751e-10, 5.1758e-11],
         [1.5844e-12, 1.3153e-14, 3.1467e-13,  ..., 9.8231e-13,
          1.3731e-12, 2.5331e-13]]], device='cuda:0')
tensor([[149, 137],
        [145, 182],
        [127, 137]], device='cuda:0')
tensor([[129, 129],
        [129, 129],
        [129, 129]], device='cuda:0')
loss for step 0 : 5.5861213974330735
epoch 15


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[-0.22128296 -0.14950562]
 [-0.10498047 -0.04989624]
 [-0.6607971  -0.58410645]]
tensor([[[2.4709e-11, 4.2368e-13, 6.6282e-12,  ..., 1.2262e-11,
          1.7624e-11, 4.0724e-12],
         [1.5545e-12, 1.5467e-14, 3.7002e-13,  ..., 7.2980e-13,
          1.1016e-12, 2.1538e-13]],

        [[2.2512e-11, 4.5180e-13, 6.3692e-12,  ..., 1.3198e-11,
          1.8518e-11, 4.4984e-12],
         [2.1499e-12, 2.8341e-14, 5.1407e-13,  ..., 1.1932e-12,
          1.7047e-12, 3.8768e-13]],

        [[7.2608e-11, 3.1996e-12, 2.6656e-11,  ..., 4.0628e-11,
          6.4030e-11, 1.9869e-11],
         [1.3361e-11, 5.7595e-13, 4.5810e-12,  ..., 6.8138e-12,
          1.0547e-11, 3.4780e-12]]], device='cuda:0')
tensor([[ 99, 108],
        [114, 121],
        [ 43,  53]], device='cuda:0')
tensor([[129, 129],
        [129, 129],
        [129, 129]], device='cuda:0')
loss for step 0 : 5.5861485937367314
epoch 16


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[-0.12786865  0.02157593]
 [-0.3451538  -0.7141113 ]
 [ 0.42874146  0.03952026]]
tensor([[[1.2221e-10, 3.3970e-12, 4.1739e-11,  ..., 7.2147e-11,
          9.5787e-11, 2.7454e-11],
         [6.6884e-12, 1.2142e-13, 2.0343e-12,  ..., 3.8310e-12,
          4.9734e-12, 1.2544e-12]],

        [[1.6895e-10, 6.1836e-12, 5.6188e-11,  ..., 1.0618e-10,
          1.3280e-10, 4.5272e-11],
         [6.3577e-12, 1.2696e-13, 1.8329e-12,  ..., 4.0964e-12,
          5.1348e-12, 1.4286e-12]],

        [[1.4868e-10, 4.4808e-12, 4.8903e-11,  ..., 8.9281e-11,
          1.2053e-10, 3.7702e-11],
         [6.8363e-12, 1.0549e-13, 1.9069e-12,  ..., 3.6687e-12,
          4.7248e-12, 1.2493e-12]]], device='cuda:0')
tensor([[111, 130],
        [ 83,  36],
        [182, 133]], device='cuda:0')
tensor([[129, 129],
        [129, 129],
        [129, 129]], device='cuda:0')
loss for step 0 : 5.586130175383195
epoch 17


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[ 0.08331299 -0.29275513]
 [-0.16522217 -0.10839844]
 [-0.07107544  0.11535645]]
tensor([[[5.7071e-12, 1.6323e-13, 1.2524e-12,  ..., 2.3925e-12,
          3.9309e-12, 1.0047e-12],
         [3.1659e-12, 6.3227e-14, 6.7623e-13,  ..., 1.3680e-12,
          2.4622e-12, 5.1586e-13]],

        [[7.7817e-12, 2.4761e-13, 1.6970e-12,  ..., 3.3093e-12,
          5.3374e-12, 1.4765e-12],
         [6.9431e-12, 1.6409e-13, 1.4688e-12,  ..., 2.9588e-12,
          4.8121e-12, 1.1404e-12]],

        [[3.0669e-12, 7.0405e-14, 6.8440e-13,  ..., 1.1286e-12,
          1.8773e-12, 4.3335e-13],
         [2.7540e-12, 4.8759e-14, 5.8082e-13,  ..., 9.8077e-13,
          1.7441e-12, 3.2913e-13]]], device='cuda:0')
tensor([[138,  90],
        [106, 114],
        [118, 142]], device='cuda:0')
tensor([[129, 129],
        [129, 129],
        [129, 129]], device='cuda:0')
loss for step 0 : 5.5861485937367314
epoch 18


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[ 0.05450439  0.34146118]
 [-0.17877197 -0.34765625]
 [ 0.08673096  0.02978516]]
tensor([[[3.1640e-11, 1.5397e-12, 9.1289e-12,  ..., 1.3579e-11,
          2.3613e-11, 7.4325e-12],
         [2.2236e-11, 7.4712e-13, 5.7396e-12,  ..., 9.0099e-12,
          1.5707e-11, 4.2406e-12]],

        [[1.1879e-11, 2.9544e-13, 2.4906e-12,  ..., 4.3035e-12,
          6.7868e-12, 1.7036e-12],
         [7.0386e-12, 1.2564e-13, 1.3005e-12,  ..., 2.5043e-12,
          4.3351e-12, 8.9723e-13]],

        [[9.9028e-12, 2.6067e-13, 2.0561e-12,  ..., 3.9886e-12,
          6.8490e-12, 1.7011e-12],
         [6.5316e-12, 1.2549e-13, 1.2411e-12,  ..., 2.5840e-12,
          4.9353e-12, 1.0362e-12]]], device='cuda:0')
tensor([[134, 171],
        [105,  83],
        [139, 131]], device='cuda:0')
tensor([[129, 129],
        [129, 129],
        [129, 129]], device='cuda:0')
loss for step 0 : 5.586148585443912
epoch 19


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[-0.00427246  0.09274292]
 [ 0.09790039 -0.09713745]
 [ 0.00643921  0.037323  ]]
tensor([[[1.1336e-10, 4.4846e-12, 2.9656e-11,  ..., 5.0024e-11,
          6.6144e-11, 2.3139e-11],
         [8.6100e-11, 3.3523e-12, 2.2654e-11,  ..., 4.0988e-11,
          5.4298e-11, 1.8358e-11]],

        [[2.7638e-10, 1.9060e-11, 7.3857e-11,  ..., 1.3651e-10,
          1.8403e-10, 6.8545e-11],
         [1.2249e-10, 5.4877e-12, 3.2285e-11,  ..., 5.8199e-11,
          8.0153e-11, 2.6508e-11]],

        [[1.0871e-10, 4.7051e-12, 2.6755e-11,  ..., 4.6809e-11,
          6.7930e-11, 2.2558e-11],
         [1.0109e-10, 4.1794e-12, 2.3289e-11,  ..., 4.2124e-11,
          5.8608e-11, 2.0026e-11]]], device='cuda:0')
tensor([[127, 139],
        [140, 115],
        [128, 132]], device='cuda:0')
tensor([[129, 129],
        [129, 129],
        [129, 129]], device='cuda:0')
loss for step 0 : 5.586139537977136
epoch 20


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[ 0.25842285 -0.14764404]
 [-0.6793213   0.0163269 ]
 [ 0.08544922  0.19766235]]
tensor([[[1.8205e-08, 3.9075e-09, 9.8065e-09,  ..., 1.4250e-08,
          2.0147e-08, 9.9522e-09],
         [2.3489e-15, 1.0342e-16, 7.3559e-16,  ..., 1.1294e-15,
          1.8914e-15, 7.3310e-16]],

        [[1.6895e-08, 3.5146e-09, 8.7427e-09,  ..., 1.2333e-08,
          1.7676e-08, 8.9513e-09],
         [2.2330e-15, 1.0138e-16, 6.6517e-16,  ..., 1.1278e-15,
          1.8394e-15, 6.9283e-16]],

        [[1.6817e-08, 3.2365e-09, 8.4054e-09,  ..., 1.1513e-08,
          1.7329e-08, 8.2284e-09],
         [2.2078e-15, 9.3336e-17, 6.4885e-16,  ..., 1.0407e-15,
          1.7743e-15, 6.9123e-16]]], device='cuda:0')
tensor([[161, 109],
        [ 41, 130],
        [138, 153]], device='cuda:0')
tensor([[129, 129],
        [129, 129],
        [129, 129]], device='cuda:0')
loss for step 0 : 5.586148353244948
epoch 21


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[-0.32998657 -0.4463501 ]
 [ 0.33047485  0.55526733]
 [-0.36849976 -0.5513916 ]]
tensor([[[1.5662e-08, 2.9065e-09, 8.2944e-09,  ..., 1.1074e-08,
          1.5754e-08, 7.2567e-09],
         [1.7713e-15, 6.3845e-17, 5.7641e-16,  ..., 8.5322e-16,
          1.4841e-15, 5.0324e-16]],

        [[1.5709e-08, 3.0017e-09, 8.0729e-09,  ..., 1.1184e-08,
          1.6591e-08, 7.6101e-09],
         [1.6541e-15, 6.4350e-17, 5.5872e-16,  ..., 8.2877e-16,
          1.3645e-15, 5.3612e-16]],

        [[1.6480e-08, 2.8709e-09, 8.4818e-09,  ..., 1.1414e-08,
          1.6662e-08, 7.8564e-09],
         [1.8661e-15, 6.8732e-17, 6.0183e-16,  ..., 8.8031e-16,
          1.4669e-15, 5.4129e-16]]], device='cuda:0')
tensor([[ 85,  70],
        [170, 199],
        [ 80,  57]], device='cuda:0')
tensor([[129, 129],
        [129, 129],
        [129, 129]], device='cuda:0')
loss for step 0 : 5.586130162943965
epoch 22


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[-0.04360962  0.19934082]
 [ 0.19662476  0.5199585 ]
 [-0.50772095 -0.36923218]]
tensor([[[2.4378e-09, 4.5053e-10, 8.1594e-10,  ..., 1.7923e-09,
          2.4118e-09, 9.5237e-10],
         [1.2151e-14, 5.1823e-16, 2.0647e-15,  ..., 7.0830e-15,
          9.1210e-15, 2.8601e-15]],

        [[2.4834e-09, 4.3365e-10, 8.2931e-10,  ..., 1.7680e-09,
          2.5467e-09, 9.8794e-10],
         [1.1305e-14, 5.2000e-16, 2.0185e-15,  ..., 6.4671e-15,
          8.5028e-15, 2.7800e-15]],

        [[2.5714e-09, 4.4895e-10, 8.5581e-10,  ..., 1.8916e-09,
          2.5281e-09, 1.0685e-09],
         [1.1568e-14, 5.2886e-16, 2.1451e-15,  ..., 6.8073e-15,
          9.0723e-15, 2.9988e-15]]], device='cuda:0')
tensor([[122, 153],
        [153, 194],
        [ 63,  80]], device='cuda:0')
tensor([[129, 129],
        [129, 129],
        [129, 129]], device='cuda:0')
loss for step 0 : 5.586139256021251
epoch 23


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[ 0.3932495   0.04443359]
 [-0.20062256  0.3748474 ]
 [ 0.6699219   0.41397095]]
tensor([[[3.4642e-09, 6.1789e-10, 1.2049e-09,  ..., 2.5627e-09,
          3.5189e-09, 1.4727e-09],
         [1.4505e-14, 6.8440e-16, 2.8352e-15,  ..., 9.3119e-15,
          1.1397e-14, 3.8026e-15]],

        [[3.6110e-09, 6.3983e-10, 1.2436e-09,  ..., 2.7068e-09,
          3.3995e-09, 1.4775e-09],
         [1.4717e-14, 6.4692e-16, 2.6219e-15,  ..., 8.5331e-15,
          1.0211e-14, 3.5171e-15]],

        [[3.7009e-09, 6.5439e-10, 1.3301e-09,  ..., 2.7583e-09,
          3.5330e-09, 1.4505e-09],
         [1.5585e-14, 6.9848e-16, 2.9228e-15,  ..., 8.8723e-15,
          1.1839e-14, 3.9025e-15]]], device='cuda:0')
tensor([[178, 133],
        [102, 175],
        [213, 180]], device='cuda:0')
tensor([[129, 129],
        [129, 129],
        [129, 129]], device='cuda:0')
loss for step 0 : 5.5861482827559765
epoch 24


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[-0.09008789  0.1777649 ]
 [ 0.11062622 -0.3357544 ]
 [-0.25967407 -0.17453003]]
tensor([[[8.1896e-09, 1.6360e-09, 3.4500e-09,  ..., 7.0216e-09,
          8.4473e-09, 4.0165e-09],
         [3.3447e-14, 1.6974e-15, 7.2819e-15,  ..., 2.5655e-14,
          2.6590e-14, 1.0860e-14]],

        [[8.6484e-09, 1.7186e-09, 3.5140e-09,  ..., 7.3682e-09,
          9.0354e-09, 4.1558e-09],
         [3.2351e-14, 1.7202e-15, 7.2035e-15,  ..., 2.3281e-14,
          2.6894e-14, 1.0216e-14]],

        [[8.7015e-09, 1.6771e-09, 3.3720e-09,  ..., 7.3806e-09,
          8.4881e-09, 4.1039e-09],
         [3.2966e-14, 1.6562e-15, 7.5187e-15,  ..., 2.4129e-14,
          2.6300e-14, 1.0233e-14]]], device='cuda:0')
tensor([[116, 150],
        [142,  85],
        [ 94, 105]], device='cuda:0')
tensor([[129, 129],
        [129, 129],
        [129, 129]], device='cuda:0')
loss for step 0 : 5.586129599032195
epoch 25


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[ 0.103302   -0.19018555]
 [-0.01809692 -0.53411865]
 [-0.0300293  -0.13287354]]
tensor([[[2.3906e-07, 1.1836e-07, 1.5040e-07,  ..., 2.7584e-07,
          2.5608e-07, 1.6534e-07],
         [5.3965e-20, 4.4268e-21, 1.4318e-20,  ..., 9.1777e-20,
          4.2941e-20, 2.0495e-20]],

        [[2.3080e-07, 1.1842e-07, 1.4937e-07,  ..., 2.5453e-07,
          2.5274e-07, 1.5879e-07],
         [4.6833e-20, 4.0860e-21, 1.2409e-20,  ..., 8.2432e-20,
          3.9184e-20, 1.8561e-20]],

        [[2.4914e-07, 1.2258e-07, 1.5102e-07,  ..., 2.8167e-07,
          2.7888e-07, 1.6877e-07],
         [5.3641e-20, 4.1993e-21, 1.2879e-20,  ..., 8.7469e-20,
          4.2068e-20, 1.9366e-20]]], device='cuda:0')
tensor([[141, 103],
        [125,  59],
        [124, 110]], device='cuda:0')
tensor([[129, 129],
        [129, 129],
        [129, 129]], device='cuda:0')
loss for step 0 : 5.586064508686895
epoch 26


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[ 0.03082275 -0.24685669]
 [-0.08218384  0.28363037]
 [-0.02377319 -0.1378479 ]]
tensor([[[1.2254e-04, 9.1994e-05, 1.0566e-04,  ..., 1.2856e-04,
          1.2145e-04, 1.0963e-04],
         [3.1304e-23, 5.9546e-24, 8.0978e-24,  ..., 4.6725e-23,
          1.3126e-23, 1.8871e-23]],

        [[1.2856e-04, 9.6215e-05, 1.0866e-04,  ..., 1.3878e-04,
          1.3034e-04, 1.1471e-04],
         [2.7247e-23, 5.2934e-24, 7.4459e-24,  ..., 4.7384e-23,
          1.3053e-23, 1.8001e-23]],

        [[1.2593e-04, 9.6792e-05, 1.1286e-04,  ..., 1.3405e-04,
          1.2904e-04, 1.1631e-04],
         [2.6967e-23, 4.8160e-24, 6.9790e-24,  ..., 4.0915e-23,
          1.1959e-23, 1.6727e-23]]], device='cuda:0')
tensor([[131,  96],
        [117, 164],
        [124, 110]], device='cuda:0')
tensor([[128, 128],
        [128, 128],
        [128, 128]], device='cuda:0')
loss for step 0 : 5.584008639791738
epoch 27


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[ 0.0435791   0.07873535]
 [ 0.21472168  0.46810913]
 [-0.10675049  0.05996704]]
tensor([[[4.7456e-06, 2.8882e-06, 3.5769e-06,  ..., 4.8475e-06,
          4.0879e-06, 3.7112e-06],
         [1.5435e-14, 4.5104e-15, 7.3897e-15,  ..., 2.0429e-14,
          1.1762e-14, 9.7321e-15]],

        [[5.5440e-06, 3.1599e-06, 4.1068e-06,  ..., 5.6395e-06,
          5.2206e-06, 4.1538e-06],
         [1.6451e-14, 4.9408e-15, 7.4663e-15,  ..., 1.9767e-14,
          1.2209e-14, 1.0494e-14]],

        [[5.2380e-06, 3.2871e-06, 3.9067e-06,  ..., 5.3010e-06,
          4.6381e-06, 4.2017e-06],
         [1.7292e-14, 5.2003e-15, 8.1431e-15,  ..., 2.1862e-14,
          1.4125e-14, 1.1218e-14]]], device='cuda:0')
tensor([[133, 138],
        [155, 187],
        [114, 135]], device='cuda:0')
tensor([[129, 129],
        [129, 129],
        [129, 129]], device='cuda:0')
loss for step 0 : 5.584788144153097
epoch 28


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[-0.11672974 -0.18936157]
 [-0.03182983 -0.09634399]
 [-0.7131958  -0.64663696]]
tensor([[[9.5306e-06, 5.8825e-06, 7.0652e-06,  ..., 8.3440e-06,
          8.1509e-06, 7.4111e-06],
         [1.4659e-14, 5.4324e-15, 8.2447e-15,  ..., 2.5005e-14,
          1.6873e-14, 1.0749e-14]],

        [[1.0622e-05, 7.0592e-06, 8.1331e-06,  ..., 1.0993e-05,
          9.9487e-06, 8.8781e-06],
         [1.3733e-14, 5.4677e-15, 7.6795e-15,  ..., 2.3589e-14,
          1.4530e-14, 1.0083e-14]],

        [[1.0835e-05, 7.0656e-06, 8.3329e-06,  ..., 1.1022e-05,
          9.6756e-06, 8.8738e-06],
         [1.2212e-14, 4.9512e-15, 7.1834e-15,  ..., 2.2999e-14,
          1.5496e-14, 9.2692e-15]]], device='cuda:0')
tensor([[113, 103],
        [123, 115],
        [ 36,  45]], device='cuda:0')
tensor([[128, 129],
        [129, 129],
        [129, 129]], device='cuda:0')
loss for step 0 : 5.584299137281335
epoch 29


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[ 0.19378662  0.638916  ]
 [-0.07131958 -0.22805786]
 [-0.00152588  0.25543213]]
tensor([[[6.0019e-09, 3.2150e-09, 4.2252e-09,  ..., 6.4580e-09,
          5.5168e-09, 4.3353e-09],
         [2.1464e-13, 9.4221e-14, 1.3301e-13,  ..., 3.7633e-13,
          3.0976e-13, 1.4607e-13]],

        [[3.9188e-09, 1.9588e-09, 2.4773e-09,  ..., 3.3642e-09,
          2.9614e-09, 2.7641e-09],
         [1.8650e-13, 8.7054e-14, 1.1898e-13,  ..., 2.9198e-13,
          2.2935e-13, 1.2990e-13]],

        [[1.0859e-08, 6.0424e-09, 7.9421e-09,  ..., 1.5112e-08,
          1.3935e-08, 8.0340e-09],
         [2.0108e-13, 8.4947e-14, 1.2426e-13,  ..., 3.2020e-13,
          2.5949e-13, 1.3520e-13]]], device='cuda:0')
tensor([[152, 209],
        [118,  98],
        [127, 160]], device='cuda:0')
tensor([[129, 129],
        [129, 129],
        [129, 129]], device='cuda:0')
loss for step 0 : 5.585287711931311
epoch 30


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[ 0.49346924  0.3227539 ]
 [-0.02279663 -0.07421875]
 [ 0.07711792 -0.02871704]]
tensor([[[1.5959e-07, 9.2923e-08, 1.1590e-07,  ..., 1.4693e-07,
          1.2430e-07, 1.2255e-07],
         [6.9083e-11, 4.0018e-11, 5.9211e-11,  ..., 1.6421e-10,
          1.7000e-10, 6.2301e-11]],

        [[1.2867e-07, 8.2805e-08, 1.0019e-07,  ..., 1.5941e-07,
          1.4645e-07, 1.0546e-07],
         [6.7134e-11, 2.8799e-11, 3.8552e-11,  ..., 5.0434e-11,
          3.7812e-11, 4.6424e-11]],

        [[1.8286e-07, 1.1758e-07, 1.4287e-07,  ..., 2.1727e-07,
          1.9642e-07, 1.4526e-07],
         [6.8690e-11, 3.7345e-11, 5.8208e-11,  ..., 1.6856e-10,
          1.9397e-10, 6.4079e-11]]], device='cuda:0')
tensor([[191, 169],
        [125, 118],
        [137, 124]], device='cuda:0')
tensor([[128, 129],
        [129, 128],
        [129, 129]], device='cuda:0')
loss for step 0 : 5.585227294590162
epoch 31


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[-0.58740234 -0.49490356]
 [ 0.42410278  0.29690552]
 [-0.03955078 -0.09347534]]
tensor([[[8.6663e-09, 4.2529e-09, 5.2662e-09,  ..., 5.6850e-09,
          4.6740e-09, 5.4493e-09],
         [1.1768e-11, 6.5145e-12, 8.3344e-12,  ..., 1.9266e-11,
          1.7158e-11, 8.3231e-12]],

        [[1.4718e-08, 8.0328e-09, 9.2108e-09,  ..., 1.1269e-08,
          9.7314e-09, 9.4781e-09],
         [8.6435e-12, 4.5047e-12, 6.3945e-12,  ..., 1.7859e-11,
          1.8213e-11, 6.8793e-12]],

        [[8.7604e-09, 4.1136e-09, 5.3562e-09,  ..., 5.3355e-09,
          4.2735e-09, 5.3470e-09],
         [9.4796e-12, 4.9879e-12, 6.9307e-12,  ..., 1.9704e-11,
          2.0986e-11, 7.1513e-12]]], device='cuda:0')
tensor([[ 52,  64],
        [182, 166],
        [122, 116]], device='cuda:0')
tensor([[128, 129],
        [128, 129],
        [128, 129]], device='cuda:0')
loss for step 0 : 5.585837683470353
epoch 32


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[ 0.29309082  0.1383667 ]
 [ 0.10903931  0.70910645]
 [-0.22268677  0.12072754]]
tensor([[[7.4362e-06, 5.5169e-06, 5.5301e-06,  ..., 8.7558e-06,
          9.9769e-06, 5.5697e-06],
         [3.7236e-09, 2.4417e-09, 2.4981e-09,  ..., 3.3780e-09,
          2.9088e-09, 2.8784e-09]],

        [[5.3543e-09, 3.7957e-09, 3.1495e-09,  ..., 2.8350e-09,
          3.6892e-09, 6.3570e-09],
         [2.2473e-08, 1.2506e-08, 1.3453e-08,  ..., 3.0340e-08,
          3.5479e-08, 1.3296e-08]],

        [[9.6919e-06, 6.5601e-06, 6.3321e-06,  ..., 1.1040e-05,
          1.3256e-05, 5.7515e-06],
         [2.0467e-08, 1.1218e-08, 1.1706e-08,  ..., 2.5992e-08,
          3.1085e-08, 1.0968e-08]]], device='cuda:0')
tensor([[165, 145],
        [141, 218],
        [ 99, 143]], device='cuda:0')
tensor([[129, 128],
        [128, 129],
        [129, 129]], device='cuda:0')
loss for step 0 : 5.58573866305144
epoch 33


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[ 0.08569336  0.2013855 ]
 [-0.3776245   0.07043457]
 [ 0.32928467 -0.01519775]]
tensor([[[2.3533e-07, 1.3616e-07, 1.4648e-07,  ..., 2.5398e-07,
          3.4745e-07, 1.2787e-07],
         [3.7188e-07, 3.0893e-07, 2.5381e-07,  ..., 3.4365e-07,
          3.1724e-07, 3.1090e-07]],

        [[1.7122e-07, 1.4110e-07, 1.1377e-07,  ..., 1.2599e-07,
          1.2350e-07, 1.6005e-07],
         [1.3875e-07, 8.0465e-08, 8.8092e-08,  ..., 1.5574e-07,
          1.9789e-07, 8.5596e-08]],

        [[4.5731e-07, 2.6945e-07, 2.7653e-07,  ..., 4.7152e-07,
          6.5328e-07, 2.4026e-07],
         [1.5574e-07, 1.3677e-07, 1.0847e-07,  ..., 1.3783e-07,
          1.3406e-07, 1.5453e-07]]], device='cuda:0')
tensor([[138, 153],
        [ 79, 137],
        [170, 126]], device='cuda:0')
tensor([[129, 128],
        [128, 129],
        [129, 128]], device='cuda:0')
loss for step 0 : 5.58535501231318
epoch 34


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[ 0.11645508  0.38427734]
 [-0.17419434 -0.09985352]
 [-0.27539062 -0.38656616]]
tensor([[[9.0288e-08, 5.3007e-08, 5.6737e-08,  ..., 1.1306e-07,
          1.4947e-07, 5.4478e-08],
         [3.2542e-10, 1.5973e-10, 1.5944e-10,  ..., 5.0303e-10,
          1.6252e-09, 2.1923e-10]],

        [[3.7665e-07, 2.6989e-07, 2.7252e-07,  ..., 4.0666e-07,
          3.6751e-07, 2.9134e-07],
         [9.1136e-09, 4.5933e-09, 5.1358e-09,  ..., 1.1616e-08,
          2.6812e-08, 5.5960e-09]],

        [[2.5796e-08, 1.4580e-08, 1.6484e-08,  ..., 3.7302e-08,
          5.6208e-08, 1.7199e-08],
         [9.2781e-08, 8.1621e-08, 6.4067e-08,  ..., 8.4204e-08,
          7.6823e-08, 9.0427e-08]]], device='cuda:0')
tensor([[142, 177],
        [105, 115],
        [ 92,  78]], device='cuda:0')
tensor([[129, 129],
        [127, 129],
        [129, 128]], device='cuda:0')
loss for step 0 : 5.585051644366721
epoch 35


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[-0.3923645  -0.8734436 ]
 [-0.01446533  0.5761719 ]
 [ 0.08300781 -0.02087402]]
tensor([[[3.3938e-07, 2.3026e-07, 2.7001e-07,  ..., 3.7767e-07,
          3.9070e-07, 2.8973e-07],
         [1.0558e-06, 8.0034e-07, 8.4331e-07,  ..., 1.0403e-06,
          9.7403e-07, 9.6260e-07]],

        [[3.3950e-06, 2.7651e-06, 2.5677e-06,  ..., 3.0684e-06,
          3.0588e-06, 3.0285e-06],
         [4.9544e-10, 3.0667e-10, 2.7759e-10,  ..., 5.3335e-10,
          1.3868e-09, 4.4502e-10]],

        [[5.5118e-08, 3.6803e-08, 3.5471e-08,  ..., 6.4456e-08,
          1.0513e-07, 4.5401e-08],
         [1.7940e-12, 1.0487e-12, 8.1258e-13,  ..., 2.0981e-12,
          1.0411e-11, 1.6570e-12]]], device='cuda:0')
tensor([[ 77,  16],
        [126, 201],
        [138, 125]], device='cuda:0')
tensor([[127, 127],
        [127, 129],
        [129, 129]], device='cuda:0')
loss for step 0 : 5.585367795695429
epoch 36


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[ 0.22079468  0.5149231 ]
 [-0.0383606  -0.14193726]
 [ 0.18127441 -0.19769287]]
tensor([[[2.0866e-19, 1.9003e-18, 1.1140e-19,  ..., 2.2167e-19,
          1.5968e-18, 1.8434e-18],
         [7.2202e-07, 6.8315e-07, 5.3883e-07,  ..., 6.5426e-07,
          5.8963e-07, 7.5161e-07]],

        [[2.0636e-07, 1.3354e-07, 1.5531e-07,  ..., 2.3255e-07,
          2.7191e-07, 1.9184e-07],
         [1.5890e-11, 1.0166e-11, 8.6103e-12,  ..., 1.9635e-11,
          1.4007e-10, 2.1578e-11]],

        [[3.6138e-07, 2.4009e-07, 2.6109e-07,  ..., 3.6332e-07,
          3.4978e-07, 2.8716e-07],
         [9.2092e-07, 8.4515e-07, 6.9728e-07,  ..., 8.6413e-07,
          8.2810e-07, 9.2674e-07]]], device='cuda:0')
tensor([[156, 193],
        [123, 109],
        [151, 102]], device='cuda:0')
tensor([[128, 128],
        [129, 129],
        [127, 127]], device='cuda:0')
loss for step 0 : 5.585076419166897
epoch 37


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[-0.13223267  0.42388916]
 [ 0.40679932  0.35263062]
 [ 0.19015503 -0.03887939]]
tensor([[[7.2105e-08, 4.3808e-08, 4.9397e-08,  ..., 6.8833e-08,
          7.3105e-08, 6.2042e-08],
         [8.0819e-09, 4.7883e-09, 4.9661e-09,  ..., 7.7027e-09,
          1.5401e-08, 6.8812e-09]],

        [[6.8331e-08, 5.0669e-08, 5.2575e-08,  ..., 6.6609e-08,
          5.3337e-08, 6.4776e-08],
         [1.5610e-09, 9.4666e-10, 9.3788e-10,  ..., 1.5606e-09,
          3.3733e-09, 1.3829e-09]],

        [[2.4202e-07, 2.1435e-07, 1.8221e-07,  ..., 2.4282e-07,
          2.1432e-07, 2.4288e-07],
         [4.6198e-07, 3.1967e-07, 3.3311e-07,  ..., 4.3651e-07,
          3.9513e-07, 4.0600e-07]]], device='cuda:0')
tensor([[111, 182],
        [180, 173],
        [152, 123]], device='cuda:0')
tensor([[127, 129],
        [127, 129],
        [127, 127]], device='cuda:0')
loss for step 0 : 5.585780110566512
epoch 38


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[ 0.11956787 -0.34005737]
 [-0.01873779  0.13815308]
 [ 0.21948242 -0.1437378 ]]
tensor([[[8.2920e-08, 5.4275e-08, 5.5769e-08,  ..., 7.7484e-08,
          6.3945e-08, 7.1333e-08],
         [1.3153e-12, 8.1297e-13, 6.0257e-13,  ..., 9.7979e-13,
          5.8768e-12, 1.4628e-12]],

        [[1.2214e-07, 8.7852e-08, 8.3885e-08,  ..., 1.1126e-07,
          9.4910e-08, 1.0999e-07],
         [1.5936e-10, 1.0271e-10, 8.5561e-11,  ..., 1.3659e-10,
          4.6879e-10, 1.8236e-10]],

        [[1.1311e-07, 8.4684e-08, 8.0693e-08,  ..., 1.1477e-07,
          8.9645e-08, 1.0429e-07],
         [2.7347e-14, 1.9297e-14, 1.2275e-14,  ..., 1.8492e-14,
          2.2323e-13, 3.1110e-14]]], device='cuda:0')
tensor([[143,  84],
        [125, 145],
        [156, 109]], device='cuda:0')
tensor([[127, 129],
        [127, 129],
        [127, 129]], device='cuda:0')
loss for step 0 : 5.585544291786525
epoch 39


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[ 0.28463745 -0.45074463]
 [-0.36386108  0.06484985]
 [ 0.0954895   0.02990723]]
tensor([[[1.2465e-07, 7.6027e-08, 7.3899e-08,  ..., 1.1192e-07,
          1.2727e-07, 1.0327e-07],
         [2.0887e-09, 4.7635e-09, 1.5806e-09,  ..., 2.0594e-09,
          3.9973e-09, 3.5461e-09]],

        [[2.2255e-09, 1.4200e-09, 1.1900e-09,  ..., 1.8010e-09,
          4.0996e-09, 2.1951e-09],
         [1.1996e-17, 1.2818e-17, 5.1533e-18,  ..., 9.7781e-18,
          2.7086e-16, 2.0209e-17]],

        [[1.3796e-07, 1.6744e-07, 1.0101e-07,  ..., 1.4408e-07,
          1.4057e-07, 1.6072e-07],
         [3.4018e-06, 2.8116e-06, 2.5798e-06,  ..., 3.4377e-06,
          2.9109e-06, 2.9750e-06]]], device='cuda:0')
tensor([[164,  70],
        [ 81, 136],
        [140, 131]], device='cuda:0')
tensor([[127, 128],
        [129, 129],
        [128, 127]], device='cuda:0')
loss for step 0 : 5.585866384920867
epoch 40


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[ 0.04489136  0.24035645]
 [-0.04360962  0.19934082]
 [-0.2144165  -0.46746826]]
tensor([[[1.6548e-07, 1.7414e-07, 1.2079e-07,  ..., 1.7889e-07,
          1.5596e-07, 1.8643e-07],
         [8.4073e-10, 5.6672e-10, 4.8291e-10,  ..., 7.1078e-10,
          2.0529e-09, 8.9870e-10]],

        [[1.6252e-07, 1.2113e-07, 1.0944e-07,  ..., 1.6301e-07,
          1.3480e-07, 1.3682e-07],
         [3.8437e-07, 2.6560e-07, 2.4431e-07,  ..., 3.3533e-07,
          3.1422e-07, 3.0903e-07]],

        [[3.2538e-07, 3.2833e-07, 2.5156e-07,  ..., 3.5730e-07,
          2.9476e-07, 3.5370e-07],
         [5.1649e-07, 3.6073e-07, 3.4514e-07,  ..., 4.6376e-07,
          3.9532e-07, 4.2165e-07]]], device='cuda:0')
tensor([[133, 158],
        [122, 153],
        [100,  68]], device='cuda:0')
tensor([[128, 129],
        [127, 127],
        [127, 127]], device='cuda:0')
loss for step 0 : 5.58448600769043
epoch 41


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[-0.09161377 -0.17907715]
 [-0.31472778 -0.21292114]
 [ 0.3781128  -0.09909058]]
tensor([[[4.7077e-16, 4.4241e-16, 2.5188e-16,  ..., 3.4994e-16,
          7.2695e-15, 9.1428e-16],
         [1.9975e-07, 1.4020e-07, 1.2064e-07,  ..., 1.7456e-07,
          2.7695e-07, 1.8437e-07]],

        [[2.1803e-07, 1.6966e-07, 1.5540e-07,  ..., 2.1290e-07,
          1.7637e-07, 1.8969e-07],
         [7.2239e-07, 6.9020e-07, 5.1964e-07,  ..., 6.9428e-07,
          6.2488e-07, 6.4608e-07]],

        [[2.1312e-07, 2.3997e-07, 1.6507e-07,  ..., 2.3793e-07,
          2.0336e-07, 2.0544e-07],
         [8.9381e-12, 7.5495e-12, 5.2153e-12,  ..., 8.2237e-12,
          4.2786e-11, 1.2771e-11]]], device='cuda:0')
tensor([[116, 105],
        [ 87, 100],
        [176, 115]], device='cuda:0')
tensor([[129, 129],
        [127, 127],
        [128, 129]], device='cuda:0')
loss for step 0 : 5.585546821096669
epoch 42


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[-0.13711548 -0.034729  ]
 [ 0.04910278 -0.21429443]
 [-0.23828125 -0.28027344]]
tensor([[[1.2359e-07, 8.8677e-08, 8.3381e-08,  ..., 1.2321e-07,
          9.3255e-08, 9.7917e-08],
         [2.2123e-12, 2.1434e-12, 1.4617e-12,  ..., 2.0343e-12,
          1.2818e-11, 3.5703e-12]],

        [[2.9798e-09, 2.2989e-09, 1.8352e-09,  ..., 2.7617e-09,
          5.4575e-09, 3.4357e-09],
         [3.5192e-07, 3.2462e-07, 2.8835e-07,  ..., 4.0643e-07,
          2.9785e-07, 3.2862e-07]],

        [[2.0101e-10, 1.5974e-10, 1.0902e-10,  ..., 1.7469e-10,
          5.2984e-10, 2.3165e-10],
         [1.6433e-08, 1.2152e-08, 9.4154e-09,  ..., 1.4436e-08,
          2.7103e-08, 1.7206e-08]]], device='cuda:0')
tensor([[110, 123],
        [134, 100],
        [ 97,  92]], device='cuda:0')
tensor([[127, 129],
        [129, 127],
        [129, 129]], device='cuda:0')
loss for step 0 : 5.585750741543977
epoch 43


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[ 0.03863525  0.5727844 ]
 [ 0.09814453  0.34472656]
 [ 0.00668335 -0.23291016]]
tensor([[[1.3628e-11, 1.0414e-11, 7.7307e-12,  ..., 1.0929e-11,
          4.5469e-11, 1.7119e-11],
         [6.6770e-16, 7.6757e-16, 4.0270e-16,  ..., 5.4250e-16,
          9.5215e-15, 1.3189e-15]],

        [[5.4668e-11, 1.4779e-10, 4.9534e-11,  ..., 7.3626e-11,
          1.1699e-10, 9.7811e-11],
         [9.1364e-20, 1.3806e-19, 6.0824e-20,  ..., 7.0382e-20,
          3.5881e-18, 2.3573e-19]],

        [[1.8082e-07, 2.0919e-07, 1.5303e-07,  ..., 2.2416e-07,
          1.8444e-07, 1.9738e-07],
         [2.0324e-07, 1.3159e-07, 1.2239e-07,  ..., 1.8456e-07,
          2.4261e-07, 1.8731e-07]]], device='cuda:0')
tensor([[132, 201],
        [140, 172],
        [128,  98]], device='cuda:0')
tensor([[129, 129],
        [128, 129],
        [127, 127]], device='cuda:0')
loss for step 0 : 5.585569771476414
epoch 44


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[ 0.25204468  0.00115967]
 [ 0.12454224  0.08319092]
 [-0.01428223  0.04608154]]
tensor([[[7.0851e-08, 7.8290e-08, 5.8049e-08,  ..., 8.3857e-08,
          6.7557e-08, 8.1993e-08],
         [8.7031e-09, 1.6262e-08, 7.8465e-09,  ..., 1.0923e-08,
          1.3246e-08, 1.3171e-08]],

        [[5.8880e-08, 5.2573e-08, 4.5486e-08,  ..., 6.3498e-08,
          4.8383e-08, 5.3869e-08],
         [5.9372e-07, 5.1633e-07, 4.3536e-07,  ..., 6.0473e-07,
          4.7693e-07, 5.7016e-07]],

        [[6.4167e-08, 7.9637e-08, 5.1687e-08,  ..., 7.9314e-08,
          7.2992e-08, 7.2576e-08],
         [1.8320e-07, 1.2719e-07, 1.2657e-07,  ..., 1.8175e-07,
          1.6393e-07, 1.5906e-07]]], device='cuda:0')
tensor([[160, 128],
        [143, 138],
        [126, 133]], device='cuda:0')
tensor([[127, 128],
        [127, 127],
        [128, 127]], device='cuda:0')
loss for step 0 : 5.585446958956511
epoch 45


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[ 0.19830322  0.29833984]
 [ 0.04602051 -0.20251465]
 [-0.01074219 -0.07440186]]
tensor([[[4.4221e-08, 2.8369e-08, 2.5658e-08,  ..., 3.9929e-08,
          4.3572e-08, 3.8082e-08],
         [2.7716e-07, 1.9669e-07, 1.7303e-07,  ..., 2.6141e-07,
          3.1532e-07, 2.4930e-07]],

        [[1.0987e-08, 1.4900e-08, 8.5144e-09,  ..., 1.2754e-08,
          1.2681e-08, 1.2717e-08],
         [1.6191e-07, 1.0920e-07, 9.7510e-08,  ..., 1.4226e-07,
          1.8636e-07, 1.5902e-07]],

        [[1.6932e-13, 6.8952e-13, 1.7373e-13,  ..., 2.4952e-13,
          5.8219e-13, 3.4664e-13],
         [3.2039e-08, 2.2062e-08, 1.9575e-08,  ..., 2.7507e-08,
          4.5369e-08, 3.0651e-08]]], device='cuda:0')
tensor([[153, 166],
        [133, 102],
        [126, 118]], device='cuda:0')
tensor([[127, 127],
        [128, 127],
        [128, 129]], device='cuda:0')
loss for step 0 : 5.5856859248617425
epoch 46


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[ 0.19589233  0.06411743]
 [-0.1506958  -0.40930176]
 [ 0.02752686 -0.4562378 ]]
tensor([[[1.5819e-07, 1.5456e-07, 1.1379e-07,  ..., 1.8070e-07,
          1.4493e-07, 1.5398e-07],
         [4.8121e-07, 3.7163e-07, 3.4187e-07,  ..., 4.5275e-07,
          3.6463e-07, 4.0611e-07]],

        [[4.9140e-09, 8.6973e-09, 4.1223e-09,  ..., 6.1311e-09,
          7.1671e-09, 6.4823e-09],
         [4.0336e-07, 4.0508e-07, 3.2801e-07,  ..., 4.2981e-07,
          3.8449e-07, 4.1569e-07]],

        [[9.4223e-08, 6.5380e-08, 5.9684e-08,  ..., 8.7026e-08,
          6.9157e-08, 7.5918e-08],
         [2.4783e-11, 2.2555e-11, 1.5609e-11,  ..., 2.1390e-11,
          9.7243e-11, 3.5669e-11]]], device='cuda:0')
tensor([[153, 136],
        [108,  75],
        [131,  69]], device='cuda:0')
tensor([[127, 127],
        [128, 127],
        [127, 129]], device='cuda:0')
loss for step 0 : 5.5846552226854405
epoch 47


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[-0.12319946 -0.60165405]
 [ 0.33605957  0.33059692]
 [-0.23852539 -0.31167603]]
tensor([[[7.8978e-08, 8.8216e-08, 5.9467e-08,  ..., 8.9845e-08,
          8.0078e-08, 8.5551e-08],
         [1.2881e-07, 8.6190e-08, 7.5646e-08,  ..., 1.0981e-07,
          1.4791e-07, 1.1450e-07]],

        [[7.4099e-12, 2.5421e-11, 7.1717e-12,  ..., 1.0296e-11,
          2.0659e-11, 1.3655e-11],
         [3.5813e-07, 3.6121e-07, 2.9434e-07,  ..., 3.8118e-07,
          3.2899e-07, 3.5778e-07]],

        [[1.2142e-08, 8.2394e-09, 6.8347e-09,  ..., 1.0759e-08,
          1.5692e-08, 1.1228e-08],
         [1.1693e-12, 1.0993e-12, 8.1471e-13,  ..., 9.9223e-13,
          6.2307e-12, 1.7411e-12]]], device='cuda:0')
tensor([[112,  50],
        [171, 170],
        [ 97,  88]], device='cuda:0')
tensor([[127, 129],
        [128, 127],
        [129, 129]], device='cuda:0')
loss for step 0 : 5.584916309688403
epoch 48


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[[ 0.00384521  0.07012939]
 [-0.04272461 -0.15814209]
 [ 0.37145996  0.25125122]]
tensor([[[1.8632e-15, 1.1938e-14, 1.9866e-15,  ..., 2.7863e-15,
          9.6954e-15, 4.2470e-15],
         [1.0733e-09, 8.2469e-10, 6.6830e-10,  ..., 9.0312e-10,
          2.4593e-09, 1.2604e-09]],

        [[2.6124e-15, 1.7473e-14, 2.8186e-15,  ..., 3.6668e-15,
          1.3893e-14, 6.0187e-15],
         [7.1296e-08, 5.0433e-08, 4.2611e-08,  ..., 5.8967e-08,
          9.8135e-08, 6.6967e-08]],

        [[9.3968e-11, 7.4003e-11, 5.5198e-11,  ..., 7.6271e-11,
          2.4134e-10, 1.1901e-10],
         [1.8187e-07, 1.2793e-07, 1.1562e-07,  ..., 1.6746e-07,
          1.6983e-07, 1.4967e-07]]], device='cuda:0')
tensor([[128, 136],
        [122, 107],
        [175, 160]], device='cuda:0')
tensor([[128, 129],
        [128, 129],
        [129, 127]], device='cuda:0')
loss for step 0 : 5.585597602180813


In [None]:
#save_path = 'transformer_vid_model2_5.7.pt'
#torch.save(wavenet.state_dict(), save_path)
load_path = 'transformer_vid_model3_5.10.pt'
wavenet.load_state_dict(torch.load(load_path))

In [None]:
class VideoOnlyDataset(Dataset):
    def __init__(self, video_frames, num_frames):
        self.video_frames = torch.tensor(video_frames, dtype=torch.float32).permute(0,3,2,1) # Permute to (N, C, H, W)
        self.num_frames = num_frames

    def __len__(self):
        return len(self.video_frames)

    def __getitem__(self, idx):
        idx += 1
        if idx < self.num_frames:
          num_zeros_needed = self.num_frames - idx
          vid_zeros = torch.zeros(num_zeros_needed, *self.video_frames[0].shape)
          vid = torch.vstack((vid_zeros, self.video_frames[0:idx])).transpose(0,1)
          return vid
        #vid shape example torch.Size([32, 3, 10, 36, 64])
        # aud shape example torch.Size([32, 2, 14710])
        vid = self.video_frames[idx-self.num_frames:idx].transpose(0,1)
        # print('idx = ', idx, ' aud.size = ', aud.shape)
        return vid

In [None]:
### generate audio for video from model
num_frames = 1

start_index = int(len(resized_vid_arr)*0.8) + num_frames
audio_list = []
vid_test_shape = (1,3,10,36,64)
zero_frame = torch.zeros((vid_test_shape[3], vid_test_shape[4]))

aud_per_vid_frame = 1500
audio_start_index = int(len(audio_array)*0.8)
print(audio_start_index)
aud_input_arr = torch.tensor(audio_array[audio_start_index:audio_start_index+(num_frames * aud_per_vid_frame) - 1], dtype=torch.float32).cuda().unsqueeze(0)
print(aud_input_arr.shape)

#aud_input_arr = torch.zeros((1,2,(num_frames * aud_per_vid_frame) - 1)).cuda()
print(aud_input_arr.shape)

input_vid = resized_vid_arr[start_index:]
gen_data = VideoOnlyDataset(input_vid, num_frames = num_frames)
gen_loader = DataLoader(gen_data, batch_size=1, shuffle=False)


#NOTE: GEN NEEDS TO BE FIXED TO ACCOUNT FOR MULTIPLE AUDIO RUNS PER FRAME
wavenet.eval()
for i,vid in enumerate(gen_loader):
    vid = vid.cuda()
    print(vid.shape)
    for j in trange(aud_per_vid_frame):
        with torch.no_grad():
            #print(aud_input_arr.shape)
            audio_output = wavenet(vid.cuda(), aud_input_arr.cuda()).cpu()
            #print(audio_output, audio_output.shape)
            audio_list.append(audio_output.squeeze().cpu().numpy())
            #print(aud_input_arr)
            aud_input_arr = aud_input_arr[:, 1:, :]
            #print(aud_input_arr.shape, audio_output.shape)
            #print(aud_input_arr)
            
            aud_input_arr = torch.cat((aud_input_arr, audio_output.cuda()), 1)
            #print(aud_input_arr)
    if i > 15:
        break
np.array(audio_list).shape

In [None]:
print(len(audio_list))

In [None]:
print(np.array(audio_list).shape)
audio_np = np.array(audio_list).transpose(1,0)
audio_np.shape

In [None]:
'''
audio_cpu = audio_output.cpu()
audio_np = audio_cpu.detach().numpy()
audio_np.shape
'''

In [None]:
path = 'test_transformer_model3_out_5.10_1.raw'
np.save(path, audio_np)

In [None]:
#print(sr)

In [None]:
fps = 30
sr = (audio_array.shape[0]// resized_vid_arr.shape[0])*fps
write('test_transformer_model3_out_5.10_1.mp3', sr, audio_np)