In [3]:
!pip install opencv-python

Collecting opencv-python
  Downloading opencv_python-4.5.4.60-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (60.3 MB)
     |████████████████████████████████| 60.3 MB 53.8 MB/s            
Installing collected packages: opencv-python
Successfully installed opencv-python-4.5.4.60


In [1]:
# OMR Model 
# Goal: recognize images of music excerpts

# Modules
import torch
from torch.autograd import Variable
import numpy as np
import pylab as pl
import torch.nn.init as init
import torch.optim as optim
import torch.nn as nn
import cv2

import matplotlib as mpl

class cnn_model(torch.nn.Module):
    def __init__(self, batch_size):
        super(cnn_model, self).__init__()

        kernel_size = [3,3]

        self.conv1 = nn.Conv2d(1, 16, kernel_size = kernel_size)
        self.batch1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16,32, kernel_size = kernel_size)
        self.batch2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32,64, kernel_size = kernel_size)
        self.batch3 = nn.BatchNorm2d(64)

        self.act = nn.LeakyReLU()
        self.pool = nn.MaxPool2d(2,2)

    def forward(self, x):

        # FORWARD PASS
        x = self.conv1(x)
        x = self.batch1(x)
        x = self.act(x)
        x = self.pool(x)

        x = self.conv2(x)
        x = self.batch2(x)
        x = self.act(x)
        x = self.pool(x)

        x = self.conv3(x)
        x = self.batch3(x)
        x = self.act(x)
        x = self.pool(x)

        output = x

        return x

class rnn_model(torch.nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size):
        super(rnn_model, self).__init__()

        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size

        #self.rnn = nn.LSTMCell(input_size = embed_size, hidden_size = hidden_size)
        self.fc = nn.Linear(hidden_size, vocab_size + 1)

    def forward(self,x, input_size):

        #h0 = torch.zeros(16, x.size(0), self.hidden_size).to(device)
        #c0 = torch.zeros(16, x.size(0), self.hidden_size).to(device)
        
        h0 = torch.zeros(16,self.hidden_size,self.hidden_size)#.to(device)
        c0 = torch.zeros(16, self.hidden_size,self.hidden_size)#.to(device)
        
        self.rnn = nn.LSTMCell(input_size = input_size, hidden_size = self.hidden_size)
        out, _ = self.rnn(x, (h0, c0))


        out = out.reshape(out.shape[0], -1)
        out = self.fc(out)

        return out
   

In [75]:
class BasicRNN(nn.Module):
    def __init__(self, batch_size, n_steps, n_inputs, n_neurons, n_outputs):
        super(BasicRNN, self).__init__()
        
        self.n_neurons = n_neurons
        self.batch_size = batch_size
        self.n_steps = n_steps
        self.n_inputs = n_inputs
        self.n_outputs = n_outputs
        
        self.basic_rnn = nn.RNN(self.n_inputs, self.n_neurons) 
        
        self.FC = nn.Linear(self.n_neurons, self.n_outputs)
        
    def init_hidden(self,):
        # (num_layers, batch_size, n_neurons)
        return (torch.zeros(1, self.batch_size, self.n_neurons))
        
    def forward(self, X):
        # transforms X to dimensions: n_steps X batch_size X n_inputs
        #X = X.permute(1, 0, 2) 
        
        self.batch_size = X.size(1)
        self.hidden = self.init_hidden()
        
        # lstm_out => n_steps, batch_size, n_neurons (hidden states for each time step)
        # self.hidden => 1, batch_size, n_neurons (final state from each lstm_out)
        #lstm_out, self.hidden = self.basic_rnn(X, self.hidden)      
        out, self.hidden = self.basic_rnn(self.n_inputs, self.n_neurons)
        out = self.FC(self.hidden)
        
        return out#.view(-1, self.n_outputs) # batch_size X n_output

In [3]:
import ctc_utils
from primus import CTC_PriMuS

In [4]:
# Data
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
corpus = './Data/package'# PATH
set = 'Data/train.txt' 
vocabulary = 'Data/vocabulary_semantic.txt'  
save_model = './trained_\semantic_model'

primus = CTC_PriMuS(corpus, set, vocabulary, semantic = True, val_split = 0.1)
primus.training_list

Training with 70880 and validating with 7875


['000136122-1_2_1',
 '230003636-1_21_2',
 '000123529-1_1_2',
 '000118332-1_1_2',
 '000135764-1_1_1',
 '000110955-1_1_1',
 '190014525-1_1_1',
 '210000218-1_2_1',
 '000122545-1_1_2',
 '000106165-1_1_1',
 '000115764-1_1_1',
 '190101947-1_1_1',
 '000115976-11_1_1',
 '220014638-1_1_2',
 '000102615-1_1_2',
 '211005421-1_4_1',
 '000127845-1_2_1',
 '190015388-1_1_1',
 '200185762-1_1_1',
 '000120336-1_1_1',
 '190018598-1_1_1',
 '211004611-1_5_1',
 '000102431-1_1_1',
 '000136986-1_1_1',
 '000105383-1_1_1',
 '000140766-1_2_1',
 '000126811-1_1_1',
 '212003679-1_1_1',
 '211007011-1_12_1',
 '230001487-1_1_1',
 '110002343-1_2_1',
 '190003571-1_1_1',
 '100016392-1_1_1',
 '000100153-1_2_1',
 '230005816-1_1_1',
 '000104575-1_2_1',
 '180000107-1_8_1',
 '210097285-1_26_1',
 '211004455-1_2_1',
 '000114468-1_1_2',
 '190001219-1_1_1',
 '190001990-1_1_1',
 '000124874-1_1_1',
 '201004334-1_21_1',
 '000142431-1_1_1',
 '225001058-1_55_1',
 '190012417-1_1_1',
 '230002835-1_3_1',
 '220000595-1_1_1',
 '000136800-1_

In [10]:
import os
os.getcwd()

'/home/myranda/Documents/DSI/ML/OMR'

In [11]:
#IMAGE DEBUGGING
sample_filepath = primus.training_list[0]
sample_fullpath = corpus + '/' + sample_filepath + '/' + sample_filepath
print(sample_fullpath)

# Get image
sample_img = cv2.imread(sample_fullpath + '.png', 0)
print(sample_img.shape)


./Data/package/000118390-1_1_2/000118390-1_1_2
(155, 1639)


In [12]:
# IMAGE DEBUGGING - MPL
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

PATH = './Data/package/' + sample_filepath + '/' + sample_filepath

img = mpimg.imread(PATH + '.png')
print(img)

[[[1. 1. 1. 1.]
  [1. 1. 1. 1.]
  [1. 1. 1. 1.]
  ...
  [1. 1. 1. 1.]
  [1. 1. 1. 1.]
  [1. 1. 1. 1.]]

 [[1. 1. 1. 1.]
  [1. 1. 1. 1.]
  [1. 1. 1. 1.]
  ...
  [1. 1. 1. 1.]
  [1. 1. 1. 1.]
  [1. 1. 1. 1.]]

 [[1. 1. 1. 1.]
  [1. 1. 1. 1.]
  [1. 1. 1. 1.]
  ...
  [1. 1. 1. 1.]
  [1. 1. 1. 1.]
  [1. 1. 1. 1.]]

 ...

 [[1. 1. 1. 1.]
  [1. 1. 1. 1.]
  [1. 1. 1. 1.]
  ...
  [1. 1. 1. 1.]
  [1. 1. 1. 1.]
  [1. 1. 1. 1.]]

 [[1. 1. 1. 1.]
  [1. 1. 1. 1.]
  [1. 1. 1. 1.]
  ...
  [1. 1. 1. 1.]
  [1. 1. 1. 1.]
  [1. 1. 1. 1.]]

 [[1. 1. 1. 1.]
  [1. 1. 1. 1.]
  [1. 1. 1. 1.]
  ...
  [1. 1. 1. 1.]
  [1. 1. 1. 1.]
  [1. 1. 1. 1.]]]


In [None]:
# CROP IMAGES


In [5]:
img_height = 128
max_epochs = 1
dropout = 0.5

batch_size = 16
vocabulary_size = primus.vocabulary_size
model_cnn = cnn_model(batch_size)
model_rnn = rnn_model(embed_size = 512, hidden_size = 512, vocab_size = primus.vocabulary_size)

In [6]:
# Loss and optimizer

learning_rate = 0.001
criterion = torch.nn.CTCLoss()
optimizer_cnn = optim.Adam(model_cnn.parameters(), lr = learning_rate) ## ADD MODEL PARAMS
optimizer_rnn = optim.Adam(model_rnn.parameters(), lr = learning_rate)
optimizer = optim.Adam(list(model_cnn.parameters()) + list(model_rnn.parameters()))

In [62]:
# Default params
# With image height of 128, width will be 1870
params = dict()
params['img_height'] = img_height
params['img_width'] = None
params['batch_size'] = 16
params['img_channels'] = 1
params['conv_blocks'] = 4
params['conv_filter_n'] = [32, 64, 128, 256]
params['conv_filter_size'] = [ [3,3], [3,3], [3,3], [3,3] ]
params['conv_pooling_size'] = [ [2,2], [2,2], [2,2], [2,2] ]
params['rnn_units'] = 512
params['rnn_layers'] = 2
params['vocabulary_size'] = vocabulary_size
params['max_width'] = 1500


In [8]:
# Input shape for CTC loss
input_shape = (None, params['img_height'])

In [27]:
batch['inputs'].shape

(16, 128, 2153, 1)

In [60]:
# Train using model_rnn
for epoch in range(max_epochs):
    train_loss = 0.
    valid_loss = 0.
    
    train_acc = 0.
    valid_acc = 0.
    
    for i in range(0, 70880 + 7875, 16):
        batch = primus.nextBatch(params)

        data = batch['inputs']

        targets = ctc_utils.sparse_tuple_from(batch['targets'])
        
        tensor_data = torch.from_numpy(data)
        print(tensor_data.shape)
        tensor_data_reshape = torch.permute(tensor_data,(0,3, 1, 2))
        
        output = model_cnn(tensor_data_reshape)
        print(output.shape)
        #output_size = 64 * 14 * output.shape[3]
        output_size = output.shape[3]
        # Reshape output for RNN
        output = output.view(output.size(0), output.size(3), -1)
        output = output.permute(0,2,1)
        #features = torch.permute(output, (3, 0, 2, 1))
        #features = torch.reshape(features, (16, features.shape[0], 64 * 14)) # width, batch, features
        output_rnn = model_rnn(output, input_size = 64*14)
        
        #Input and target shape
        input_shape = (None, params['img_height'], tensor_data_reshape.shape[3],1)
        target_shape = batch['seq_lengths']
        
        loss = criterion(output, targets, input_shape, target_shape)

        loss.backward()
        optimizer.step()

        #Calc loss
        train_loss += loss.detach().item()
        train_acc += 0 # ADD ACCURACY
    print(train_loss)

torch.Size([16, 128, 2417, 1])
torch.Size([16, 64, 14, 300])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (14336x300 and 896x2048)

In [None]:
# MAKE BATCH SIZE ALL IMAGES TO KEEP WIDTHS THE SAME
# Train using model_rnn
for epoch in range(max_epochs):
    train_loss = 0.
    valid_loss = 0.
    
    train_acc = 0.
    valid_acc = 0.
    
    for i in range(0, 70880 + 7875, 16):
        batch = primus.nextBatch(params)

        data = batch['inputs']

        targets = ctc_utils.sparse_tuple_from(batch['targets'])
        
        tensor_data = torch.from_numpy(data)
        print(tensor_data.shape)
        tensor_data_reshape = torch.permute(tensor_data,(0,3, 1, 2))
        
        output = model_cnn(tensor_data_reshape)
        print(output.shape)
        #output_size = 64 * 14 * output.shape[3]
        output_size = output.shape[3]
        # Reshape output for RNN
        output = output.view(output.size(0), output.size(3), -1)
        output = output.permute(0,2,1)
        #features = torch.permute(output, (3, 0, 2, 1))
        #features = torch.reshape(features, (16, features.shape[0], 64 * 14)) # width, batch, features
        output_rnn = model_rnn(output, input_size = 64*14)
        
        #Input and target shape
        input_shape = (None, params['img_height'], tensor_data_reshape.shape[3],1)
        target_shape = batch['seq_lengths']
        
        loss = criterion(output, targets, input_shape, target_shape)

        loss.backward()
        optimizer.step()

        #Calc loss
        train_loss += loss.detach().item()
        train_acc += 0 # ADD ACCURACY
    print(train_loss)

In [72]:
14336/896

16.0

In [76]:
# Train using Basic RNN

# Setup
BATCH_SIZE = 16
IMG_HEIGHT = img_height
N_EPOCHS = 1
N_OUTPUTS = vocabulary_size + 1
N_NEURONS = 512
#N_INPUTS = 512
N_INPUTS = 896
model_cnn = cnn_model(BATCH_SIZE)
basic_rnn = BasicRNN(BATCH_SIZE, 1, N_INPUTS, N_NEURONS, N_OUTPUTS)
optimizer = optim.Adam(list(model_cnn.parameters()) + list(basic_rnn.parameters()))
len_data = len(primus.training_list) + len(primus.validation_list)

In [78]:
# Train

for epoch in range(N_EPOCHS):
    train_loss = 0.
    train_acc = 0.
    model_cnn.train()
    basic_rnn.train()
    
    for i in range(0, len_data, BATCH_SIZE):
        # zero parameter gradients
        optimizer.zero_grad()
        
        # reset hidden states
        basic_rnn.hidden = basic_rnn.init_hidden()
        
        # Get inputs
        batch = primus.nextBatch(params)

        data = batch['inputs'] # size (batch, height, width, channels)
        #print(data)
        #print(data.shape)
        max_input_length = data.shape[2]
        
        # list of indices, values, shape
        seq_len = int(batch['seq_lengths'][0])
        targets = ctc_utils.sparse_tuple_from(batch['targets'])
        #print(tuple(targets[2]))
        #print(type(t[0]) for t in targets)
        #targets = torch.sparse_coo_tensor(targets[0], targets[1], tuple(targets[2]))
        targets_0 = torch.as_tensor((targets[0]))
        #print(targets.shape)
        #targets = torch.reshape(targets, (16, 1))
        padded_targets, lengths = ctc_utils.pad_sequences(batch['targets'], maxlen=max_input_length)
        padded_targets_tensor = torch.tensor(padded_targets)
        
        tensor_data = torch.from_numpy(data)
        #print(tensor_data.shape)
        tensor_data_reshape = torch.permute(tensor_data,(0,3, 1, 2))
        
        # forward, backward, optim
        cnn_output = model_cnn(tensor_data_reshape)
        output_size = 64 * 14 * cnn_output.shape[3]
        #print(cnn_output.shape)
        #print(cnn_output[0])
        print(cnn_output[0].shape)
        print(cnn_output[0][0].shape)
        
        # Change shape for rnn
        output = torch.reshape(cnn_output, (cnn_output.shape[3], 16, 64 * 14)) # width, batch, features
        print(output.shape)
        rnn_output = basic_rnn(output)
        print(rnn_output[0].shape)
        print(batch['seq_lengths'])
        
        #rnn_output_reshape = torch.reshape(rnn_output, (cnn_output[0].shape[2], BATCH_SIZE, N_OUTPUTS))
        #rnn_output_reshape = torch.reshape(rnn_output[0], (1, BATCH_SIZE, N_OUTPUTS))
        #rnn_output_reshape = rnn_output[0].view(-1, BATCH_SIZE, N_OUTPUTS)
        
        
        log_probs = nn.functional.log_softmax(rnn_output)
        #Input and target shape
        #print(rnn_output_reshape.shape)
        input_shape = (BATCH_SIZE, params['img_height'], tensor_data_reshape.shape[3],1)
        input_len = tuple([1 for i in range (0, BATCH_SIZE)])
        #print(input_shape)
        target_shape = tuple(int(b) for b in batch['seq_lengths'])
        
        # MUST BE TENSOR, TENSOR, TUPLE, TUPLE OR TENSOR TENSOR TENSOR TENSOR
        #loss = criterion(rnn_output_reshape, padded_targets_tensor, input_len, target_shape)
        #loss = criterion(log_probs, padded_targets_tensor, input_len, target_shape)
        loss = criterion(log_probs, padded_targets_tensor, target_shape, tuple(lengths))
        loss.backward()
        optimizer.step()
        
        train_loss += loss.detach().item()
        train_acc += 0
        print("Loss: %f", train_loss)
        
    #model.eval()
    print('training loss:')
    print(train_loss)
        

torch.Size([64, 14, 230])
torch.Size([14, 230])
torch.Size([230, 16, 896])


AttributeError: 'int' object has no attribute 'size'

In [80]:
type(output)

torch.Tensor

In [263]:
28512/16
# 16 times seq_len * n_inputs

1782.0

In [261]:
128 * data.shape[1]

16384

In [146]:
targets[2].shape

(2,)

In [158]:
torch.as_tensor(tuple(batch['targets']))

ValueError: expected sequence of length 26 at dim 1 (got 20)

In [47]:
padded_targets, lengths = ctc_utils.pad_sequences(batch['targets'], maxlen=125)
len(padded_targets)
lengths

array([25, 36, 24, 31, 21, 18, 27, 31, 15, 26, 26, 18, 17, 15, 18, 38])

In [46]:
torch.tensor(padded_targets).shape

torch.Size([16, 125])

In [168]:
padded_targets_list = [torch.tensor(padded_targets[i]) for i in range(0,len(padded_targets))]
padded_targets_list

[tensor([[1.0000e+01, 2.3400e+02, 1.7790e+03, 1.5990e+03, 0.0000e+00, 1.0180e+03,
          1.0180e+03, 1.0180e+03, 1.0180e+03, 1.6470e+03, 1.4830e+03, 1.2370e+03,
          1.0360e+03, 0.0000e+00, 8.2300e+02, 6.0400e+02, 8.5300e+02, 4.0200e+02,
          1.0180e+03, 6.0400e+02, 4.2600e+02, 1.6180e+03, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00],
         [7.0000e+00, 2.2800e+02, 1.7800e+03, 1.7220e+03, 9.8300e+02, 0.0000e+00,
          9.8300e+02, 9.8300e+02, 0.0000e+00, 3.8100e+02, 5.6100e+02, 7.7900e+02,
          0.0000e+00, 9.9200e+02, 7.9000e+02, 5.5600e+02, 0.0000e+00, 3.7400e+02,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.000

In [50]:
torch.tensor(lengths)

tensor([29, 31, 23, 26, 37, 27, 22, 23, 13, 20, 20, 21, 17, 30, 24, 25])

In [245]:
len(data[0][1])

1743

In [203]:
rnn_output.shape

torch.Size([16, 1782])

In [209]:
len(data[1])

128

In [130]:
sum(target_shape)

1520

In [108]:
# DATA
# num steps: IMAGE WIDTH
# batch size 16
# n_inputs 64 * 14 (from CNN output)
# output of CNN: (64 by 14 by width) - width same across batch

class ImageRNN(nn.Module):
    def __init__(self, batch_size = 16, n_inputs = 896, n_neurons = 512, n_outputs = vocabulary_size +1): # N_ STEPS AFTER BATCH_SIZE
        super(ImageRNN, self).__init__()
        
        self.n_neurons = n_neurons
        #self.batch_size = batch_size
        #self.n_steps = n_steps
        self.n_inputs = n_inputs
        self.n_outputs = n_outputs
        
        self.basic_rnn = nn.RNN(self.n_inputs, self.n_neurons) 
        
        self.FC = nn.Linear(self.n_neurons, self.n_outputs)
        
    def init_hidden(self, batch_size):
        # (num_layers, batch_size, n_neurons)
        return (torch.zeros(1, batch_size, self.n_neurons))
        
    def forward(self, X):
        # transforms X to dimensions: n_steps X batch_size X n_inputs
        #X = X.permute(1, 0, 2) 
        # maybe batch size should be width
        # each batch is 1 by 64 by 14
        
        self.batch_size = X.size(2)
        self.hidden = self.init_hidden(self.batch_size)
        
        # try using a loop - delete this if it breaks
        #lstm_out, self.hidden = self.basic_rnn(X, self.hidden)      
        #out = self.FC(self.hidden)
        out = []
        
        for x in X:
            lstm_out, self.hidden = self.basic_rnn(x, self.hidden)
            out_step = self.FC(self.hidden)
            out.append(out_step)
            
            
        return out#.view(-1, self.n_outputs) # batch_size X n_output
    #output represent log prob of model

In [136]:
# Train using ImageRNN

# SETUP
model_cnn = cnn_model(BATCH_SIZE)
model_rnn = ImageRNN()
optimizer = optim.Adam(list(model_cnn.parameters()) + list(basic_rnn.parameters()))
len_data = len(primus.training_list) + len(primus.validation_list)


for epoch in range(N_EPOCHS):
    train_loss = 0.
    train_acc = 0.
    model_cnn.train()
    model_rnn.train()
    
    for i in range(0, len_data, BATCH_SIZE):
        # zero parameter gradients
        optimizer.zero_grad()
        
        # reset hidden states
        model_rnn.hidden = basic_rnn.init_hidden()
        
        # Get inputs
        batch = primus.nextBatch(params)

        data = batch['inputs'] # size (batch, height, width, channels)
        #print(data)
        #print(data.shape)
        max_input_length = data.shape[2]
        
        # list of indices, values, shape
        seq_len = int(batch['seq_lengths'][0])
        targets = ctc_utils.sparse_tuple_from(batch['targets'])
        #print(tuple(targets[2]))
        #print(type(t[0]) for t in targets)
        #targets = torch.sparse_coo_tensor(targets[0], targets[1], tuple(targets[2]))
        targets_0 = torch.as_tensor((targets[0]))
        #print(targets.shape)
        #targets = torch.reshape(targets, (16, 1))
        padded_targets, lengths = ctc_utils.pad_sequences(batch['targets'], maxlen=max_input_length)
        padded_targets_tensor = torch.tensor(padded_targets)
        
        tensor_data = torch.from_numpy(data)
        #print(tensor_data.shape)
        tensor_data_reshape = torch.permute(tensor_data,(0,3, 1, 2))
        
        # forward, backward, optim
        cnn_output = model_cnn(tensor_data_reshape)
        output_size = 64 * 14 * cnn_output.shape[3]
        #print(cnn_output.shape)
        #print(cnn_output[0])
        print(cnn_output[0].shape)
        print(cnn_output[0][0].shape)
        
        # Change shape for rnn
        #output = torch.reshape(cnn_output, (cnn_output.shape[3], 16, 64 * 14)) # width, batch, features
        
        # SEE IF LOOP OVER OUTPUT
        output = torch.reshape(cnn_output, (16, 1, cnn_output.shape[3], 64*14))
        
        print(output.shape)
        logits = model_rnn(output)
        print(logits[0].shape) #1 by width by vocab size
        print((logits[0][0].shape))
        print(logits[0][0][0].shape) # this is logits for one column of picture
        print(batch['seq_lengths'])
        
        for t in range(0,len(logits)):
            logits[t] = torch.reshape(logits[t],(logits[t].size(1), logits[t].size(2)))
        print(logits[0].shape[0])
        
        batch_size = logits[0].shape[0]
        logits_tensor = torch.stack(logits)
        logits_tensor = torch.permute(logits_tensor, (1,0,2))
        
        #rnn_output_reshape = torch.reshape(rnn_output, (cnn_output[0].shape[2], BATCH_SIZE, N_OUTPUTS))
        #rnn_output_reshape = torch.reshape(rnn_output[0], (1, BATCH_SIZE, N_OUTPUTS))
        #rnn_output_reshape = rnn_output[0].view(-1, BATCH_SIZE, N_OUTPUTS)
        
        
        #log_probs = nn.functional.log_softmax(rnn_output, dim=0)
        #Input and target shape
        #print(rnn_output_reshape.shape)
        input_shape = (BATCH_SIZE, params['img_height'], tensor_data_reshape.shape[3],1)
        input_len = tuple([batch_size for i in range (0, 16)])
        #print(input_shape)
        print(input_len)
        print(logits_tensor.shape)
        target_shape = tuple(int(b) for b in batch['seq_lengths'])
        
        # MUST BE TENSOR, TENSOR, TUPLE, TUPLE OR TENSOR TENSOR TENSOR TENSOR
        #loss = criterion(rnn_output_reshape, padded_targets_tensor, input_len, target_shape)
        #loss = criterion(log_probs, padded_targets_tensor, input_len, target_shape)
        loss = criterion(logits_tensor, padded_targets_tensor, input_len, tuple(lengths))
        loss.backward()
        optimizer.step()
        
        train_loss += loss.detach().item()
        train_acc += 0
        print("Loss: %f", train_loss)
        
    #model.eval()
    print('training loss:')
    print(train_loss)
        

torch.Size([64, 14, 267])
torch.Size([14, 267])
torch.Size([16, 1, 267, 896])
torch.Size([1, 267, 1782])
torch.Size([267, 1782])
torch.Size([1782])
[134.5625 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625
 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625]
267
(267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267)
torch.Size([267, 16, 1782])
Loss: %f -5.809357166290283
torch.Size([64, 14, 204])
torch.Size([14, 204])
torch.Size([16, 1, 204, 896])
torch.Size([1, 204, 1782])
torch.Size([204, 1782])
torch.Size([1782])
[103.3125 103.3125 103.3125 103.3125 103.3125 103.3125 103.3125 103.3125
 103.3125 103.3125 103.3125 103.3125 103.3125 103.3125 103.3125 103.3125]
204
(204, 204, 204, 204, 204, 204, 204, 204, 204, 204, 204, 204, 204, 204, 204, 204)
torch.Size([204, 16, 1782])
Loss: %f -10.957740783691406
torch.Size([64, 14, 315])
torch.Size([14, 315])
torch.Size([16, 1, 315, 896])
torch.Size([1, 315, 1782])
torch.Size([315, 1782]

Loss: %f -109.83339929580688
torch.Size([64, 14, 206])
torch.Size([14, 206])
torch.Size([16, 1, 206, 896])
torch.Size([1, 206, 1782])
torch.Size([206, 1782])
torch.Size([1782])
[104. 104. 104. 104. 104. 104. 104. 104. 104. 104. 104. 104. 104. 104.
 104. 104.]
206
(206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206)
torch.Size([206, 16, 1782])
Loss: %f -114.90282392501831
torch.Size([64, 14, 176])
torch.Size([14, 176])
torch.Size([16, 1, 176, 896])
torch.Size([1, 176, 1782])
torch.Size([176, 1782])
torch.Size([1782])
[89.125 89.125 89.125 89.125 89.125 89.125 89.125 89.125 89.125 89.125
 89.125 89.125 89.125 89.125 89.125 89.125]
176
(176, 176, 176, 176, 176, 176, 176, 176, 176, 176, 176, 176, 176, 176, 176, 176)
torch.Size([176, 16, 1782])
Loss: %f -119.55754661560059
torch.Size([64, 14, 286])
torch.Size([14, 286])
torch.Size([16, 1, 286, 896])
torch.Size([1, 286, 1782])
torch.Size([286, 1782])
torch.Size([1782])
[144.25 144.25 144.25 144.25 144.25 144.25 14

Loss: %f -214.55992221832275
torch.Size([64, 14, 248])
torch.Size([14, 248])
torch.Size([16, 1, 248, 896])
torch.Size([1, 248, 1782])
torch.Size([248, 1782])
torch.Size([1782])
[125.125 125.125 125.125 125.125 125.125 125.125 125.125 125.125 125.125
 125.125 125.125 125.125 125.125 125.125 125.125 125.125]
248
(248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248)
torch.Size([248, 16, 1782])
Loss: %f -219.7144546508789
torch.Size([64, 14, 210])
torch.Size([14, 210])
torch.Size([16, 1, 210, 896])
torch.Size([1, 210, 1782])
torch.Size([210, 1782])
torch.Size([1782])
[106.0625 106.0625 106.0625 106.0625 106.0625 106.0625 106.0625 106.0625
 106.0625 106.0625 106.0625 106.0625 106.0625 106.0625 106.0625 106.0625]
210
(210, 210, 210, 210, 210, 210, 210, 210, 210, 210, 210, 210, 210, 210, 210, 210)
torch.Size([210, 16, 1782])
Loss: %f -224.58364343643188
torch.Size([64, 14, 303])
torch.Size([14, 303])
torch.Size([16, 1, 303, 896])
torch.Size([1, 303, 1782])
torch.Siz

Loss: %f -317.5003023147583
torch.Size([64, 14, 272])
torch.Size([14, 272])
torch.Size([16, 1, 272, 896])
torch.Size([1, 272, 1782])
torch.Size([272, 1782])
torch.Size([1782])
[136.875 136.875 136.875 136.875 136.875 136.875 136.875 136.875 136.875
 136.875 136.875 136.875 136.875 136.875 136.875 136.875]
272
(272, 272, 272, 272, 272, 272, 272, 272, 272, 272, 272, 272, 272, 272, 272, 272)
torch.Size([272, 16, 1782])
Loss: %f -322.7487440109253
torch.Size([64, 14, 168])
torch.Size([14, 168])
torch.Size([16, 1, 168, 896])
torch.Size([1, 168, 1782])
torch.Size([168, 1782])
torch.Size([1782])
[85.3125 85.3125 85.3125 85.3125 85.3125 85.3125 85.3125 85.3125 85.3125
 85.3125 85.3125 85.3125 85.3125 85.3125 85.3125 85.3125]
168
(168, 168, 168, 168, 168, 168, 168, 168, 168, 168, 168, 168, 168, 168, 168, 168)
torch.Size([168, 16, 1782])
Loss: %f -327.0855646133423
torch.Size([64, 14, 258])
torch.Size([14, 258])
torch.Size([16, 1, 258, 896])
torch.Size([1, 258, 1782])
torch.Size([258, 1782])
tor

Loss: %f -415.1536226272583
torch.Size([64, 14, 230])
torch.Size([14, 230])
torch.Size([16, 1, 230, 896])
torch.Size([1, 230, 1782])
torch.Size([230, 1782])
torch.Size([1782])
[116. 116. 116. 116. 116. 116. 116. 116. 116. 116. 116. 116. 116. 116.
 116. 116.]
230
(230, 230, 230, 230, 230, 230, 230, 230, 230, 230, 230, 230, 230, 230, 230, 230)
torch.Size([230, 16, 1782])
Loss: %f -419.9722843170166
torch.Size([64, 14, 222])
torch.Size([14, 222])
torch.Size([16, 1, 222, 896])
torch.Size([1, 222, 1782])
torch.Size([222, 1782])
torch.Size([1782])
[112. 112. 112. 112. 112. 112. 112. 112. 112. 112. 112. 112. 112. 112.
 112. 112.]
222
(222, 222, 222, 222, 222, 222, 222, 222, 222, 222, 222, 222, 222, 222, 222, 222)
torch.Size([222, 16, 1782])
Loss: %f -424.7477436065674
torch.Size([64, 14, 217])
torch.Size([14, 217])
torch.Size([16, 1, 217, 896])
torch.Size([1, 217, 1782])
torch.Size([217, 1782])
torch.Size([1782])
[109.6875 109.6875 109.6875 109.6875 109.6875 109.6875 109.6875 109.6875
 109.68

Loss: %f -512.556670665741
torch.Size([64, 14, 274])
torch.Size([14, 274])
torch.Size([16, 1, 274, 896])
torch.Size([1, 274, 1782])
torch.Size([274, 1782])
torch.Size([1782])
[138.0625 138.0625 138.0625 138.0625 138.0625 138.0625 138.0625 138.0625
 138.0625 138.0625 138.0625 138.0625 138.0625 138.0625 138.0625 138.0625]
274
(274, 274, 274, 274, 274, 274, 274, 274, 274, 274, 274, 274, 274, 274, 274, 274)
torch.Size([274, 16, 1782])
Loss: %f -517.3990430831909
torch.Size([64, 14, 248])
torch.Size([14, 248])
torch.Size([16, 1, 248, 896])
torch.Size([1, 248, 1782])
torch.Size([248, 1782])
torch.Size([1782])
[125.125 125.125 125.125 125.125 125.125 125.125 125.125 125.125 125.125
 125.125 125.125 125.125 125.125 125.125 125.125 125.125]
248
(248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248)
torch.Size([248, 16, 1782])
Loss: %f -522.3455386161804
torch.Size([64, 14, 258])
torch.Size([14, 258])
torch.Size([16, 1, 258, 896])
torch.Size([1, 258, 1782])
torch.Size([

Loss: %f -608.2196850776672
torch.Size([64, 14, 204])
torch.Size([14, 204])
torch.Size([16, 1, 204, 896])
torch.Size([1, 204, 1782])
torch.Size([204, 1782])
torch.Size([1782])
[103.3125 103.3125 103.3125 103.3125 103.3125 103.3125 103.3125 103.3125
 103.3125 103.3125 103.3125 103.3125 103.3125 103.3125 103.3125 103.3125]
204
(204, 204, 204, 204, 204, 204, 204, 204, 204, 204, 204, 204, 204, 204, 204, 204)
torch.Size([204, 16, 1782])
Loss: %f -612.788568019867
torch.Size([64, 14, 220])
torch.Size([14, 220])
torch.Size([16, 1, 220, 896])
torch.Size([1, 220, 1782])
torch.Size([220, 1782])
torch.Size([1782])
[111.25 111.25 111.25 111.25 111.25 111.25 111.25 111.25 111.25 111.25
 111.25 111.25 111.25 111.25 111.25 111.25]
220
(220, 220, 220, 220, 220, 220, 220, 220, 220, 220, 220, 220, 220, 220, 220, 220)
torch.Size([220, 16, 1782])
Loss: %f -617.4920558929443
torch.Size([64, 14, 209])
torch.Size([14, 209])
torch.Size([16, 1, 209, 896])
torch.Size([1, 209, 1782])
torch.Size([209, 1782])
torc

Loss: %f -701.6608023643494
torch.Size([64, 14, 167])
torch.Size([14, 167])
torch.Size([16, 1, 167, 896])
torch.Size([1, 167, 1782])
torch.Size([167, 1782])
torch.Size([1782])
[84.4375 84.4375 84.4375 84.4375 84.4375 84.4375 84.4375 84.4375 84.4375
 84.4375 84.4375 84.4375 84.4375 84.4375 84.4375 84.4375]
167
(167, 167, 167, 167, 167, 167, 167, 167, 167, 167, 167, 167, 167, 167, 167, 167)
torch.Size([167, 16, 1782])
Loss: %f -705.9013743400574
torch.Size([64, 14, 232])
torch.Size([14, 232])
torch.Size([16, 1, 232, 896])
torch.Size([1, 232, 1782])
torch.Size([232, 1782])
torch.Size([1782])
[116.9375 116.9375 116.9375 116.9375 116.9375 116.9375 116.9375 116.9375
 116.9375 116.9375 116.9375 116.9375 116.9375 116.9375 116.9375 116.9375]
232
(232, 232, 232, 232, 232, 232, 232, 232, 232, 232, 232, 232, 232, 232, 232, 232)
torch.Size([232, 16, 1782])
Loss: %f -710.6034526824951
torch.Size([64, 14, 216])
torch.Size([14, 216])
torch.Size([16, 1, 216, 896])
torch.Size([1, 216, 1782])
torch.Size(

Loss: %f -795.5543484687805
torch.Size([64, 14, 267])
torch.Size([14, 267])
torch.Size([16, 1, 267, 896])
torch.Size([1, 267, 1782])
torch.Size([267, 1782])
torch.Size([1782])
[134.5625 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625
 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625]
267
(267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267)
torch.Size([267, 16, 1782])
Loss: %f -800.6080007553101
torch.Size([64, 14, 248])
torch.Size([14, 248])
torch.Size([16, 1, 248, 896])
torch.Size([1, 248, 1782])
torch.Size([248, 1782])
torch.Size([1782])
[125.125 125.125 125.125 125.125 125.125 125.125 125.125 125.125 125.125
 125.125 125.125 125.125 125.125 125.125 125.125 125.125]
248
(248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248)
torch.Size([248, 16, 1782])
Loss: %f -805.5733122825623
torch.Size([64, 14, 210])
torch.Size([14, 210])
torch.Size([16, 1, 210, 896])
torch.Size([1, 210, 1782])
torch.Size(

Loss: %f -891.3083038330078
torch.Size([64, 14, 267])
torch.Size([14, 267])
torch.Size([16, 1, 267, 896])
torch.Size([1, 267, 1782])
torch.Size([267, 1782])
torch.Size([1782])
[134.5625 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625
 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625]
267
(267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267)
torch.Size([267, 16, 1782])
Loss: %f -896.3361268043518
torch.Size([64, 14, 248])
torch.Size([14, 248])
torch.Size([16, 1, 248, 896])
torch.Size([1, 248, 1782])
torch.Size([248, 1782])
torch.Size([1782])
[125.125 125.125 125.125 125.125 125.125 125.125 125.125 125.125 125.125
 125.125 125.125 125.125 125.125 125.125 125.125 125.125]
248
(248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248)
torch.Size([248, 16, 1782])
Loss: %f -901.223283290863
torch.Size([64, 14, 248])
torch.Size([14, 248])
torch.Size([16, 1, 248, 896])
torch.Size([1, 248, 1782])
torch.Size([

Loss: %f -984.3739902973175
torch.Size([64, 14, 267])
torch.Size([14, 267])
torch.Size([16, 1, 267, 896])
torch.Size([1, 267, 1782])
torch.Size([267, 1782])
torch.Size([1782])
[134.5625 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625
 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625]
267
(267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267)
torch.Size([267, 16, 1782])
Loss: %f -989.3880245685577
torch.Size([64, 14, 222])
torch.Size([14, 222])
torch.Size([16, 1, 222, 896])
torch.Size([1, 222, 1782])
torch.Size([222, 1782])
torch.Size([1782])
[112. 112. 112. 112. 112. 112. 112. 112. 112. 112. 112. 112. 112. 112.
 112. 112.]
222
(222, 222, 222, 222, 222, 222, 222, 222, 222, 222, 222, 222, 222, 222, 222, 222)
torch.Size([222, 16, 1782])
Loss: %f -994.2327435016632
torch.Size([64, 14, 172])
torch.Size([14, 172])
torch.Size([16, 1, 172, 896])
torch.Size([1, 172, 1782])
torch.Size([172, 1782])
torch.Size([1782])
[86.9375 86.9375

Loss: %f -1079.0521070957184
torch.Size([64, 14, 267])
torch.Size([14, 267])
torch.Size([16, 1, 267, 896])
torch.Size([1, 267, 1782])
torch.Size([267, 1782])
torch.Size([1782])
[134.5625 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625
 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625]
267
(267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267)
torch.Size([267, 16, 1782])
Loss: %f -1084.0468981266022
torch.Size([64, 14, 267])
torch.Size([14, 267])
torch.Size([16, 1, 267, 896])
torch.Size([1, 267, 1782])
torch.Size([267, 1782])
torch.Size([1782])
[134.5625 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625
 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625]
267
(267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267)
torch.Size([267, 16, 1782])
Loss: %f -1089.1539552211761
torch.Size([64, 14, 252])
torch.Size([14, 252])
torch.Size([16, 1, 252, 896])
torch.Size([1, 252,

Loss: %f -1174.3202469348907
torch.Size([64, 14, 217])
torch.Size([14, 217])
torch.Size([16, 1, 217, 896])
torch.Size([1, 217, 1782])
torch.Size([217, 1782])
torch.Size([1782])
[109.75 109.75 109.75 109.75 109.75 109.75 109.75 109.75 109.75 109.75
 109.75 109.75 109.75 109.75 109.75 109.75]
217
(217, 217, 217, 217, 217, 217, 217, 217, 217, 217, 217, 217, 217, 217, 217, 217)
torch.Size([217, 16, 1782])
Loss: %f -1179.075739145279
torch.Size([64, 14, 267])
torch.Size([14, 267])
torch.Size([16, 1, 267, 896])
torch.Size([1, 267, 1782])
torch.Size([267, 1782])
torch.Size([1782])
[134.5625 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625
 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625]
267
(267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267)
torch.Size([267, 16, 1782])
Loss: %f -1183.9507467746735
torch.Size([64, 14, 246])
torch.Size([14, 246])
torch.Size([16, 1, 246, 896])
torch.Size([1, 246, 1782])
torch.Size([246, 1782])
t

Loss: %f -1269.7977983951569
torch.Size([64, 14, 267])
torch.Size([14, 267])
torch.Size([16, 1, 267, 896])
torch.Size([1, 267, 1782])
torch.Size([267, 1782])
torch.Size([1782])
[134.5625 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625
 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625]
267
(267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267)
torch.Size([267, 16, 1782])
Loss: %f -1274.671216249466
torch.Size([64, 14, 204])
torch.Size([14, 204])
torch.Size([16, 1, 204, 896])
torch.Size([1, 204, 1782])
torch.Size([204, 1782])
torch.Size([1782])
[103.3125 103.3125 103.3125 103.3125 103.3125 103.3125 103.3125 103.3125
 103.3125 103.3125 103.3125 103.3125 103.3125 103.3125 103.3125 103.3125]
204
(204, 204, 204, 204, 204, 204, 204, 204, 204, 204, 204, 204, 204, 204, 204, 204)
torch.Size([204, 16, 1782])
Loss: %f -1279.2319700717926
torch.Size([64, 14, 248])
torch.Size([14, 248])
torch.Size([16, 1, 248, 896])
torch.Size([1, 248, 

Loss: %f -1365.779840707779
torch.Size([64, 14, 204])
torch.Size([14, 204])
torch.Size([16, 1, 204, 896])
torch.Size([1, 204, 1782])
torch.Size([204, 1782])
torch.Size([1782])
[103.3125 103.3125 103.3125 103.3125 103.3125 103.3125 103.3125 103.3125
 103.3125 103.3125 103.3125 103.3125 103.3125 103.3125 103.3125 103.3125]
204
(204, 204, 204, 204, 204, 204, 204, 204, 204, 204, 204, 204, 204, 204, 204, 204)
torch.Size([204, 16, 1782])
Loss: %f -1370.2930133342743
torch.Size([64, 14, 190])
torch.Size([14, 190])
torch.Size([16, 1, 190, 896])
torch.Size([1, 190, 1782])
torch.Size([190, 1782])
torch.Size([1782])
[96. 96. 96. 96. 96. 96. 96. 96. 96. 96. 96. 96. 96. 96. 96. 96.]
190
(190, 190, 190, 190, 190, 190, 190, 190, 190, 190, 190, 190, 190, 190, 190, 190)
torch.Size([190, 16, 1782])
Loss: %f -1374.6385538578033
torch.Size([64, 14, 279])
torch.Size([14, 279])
torch.Size([16, 1, 279, 896])
torch.Size([1, 279, 1782])
torch.Size([279, 1782])
torch.Size([1782])
[140.4375 140.4375 140.4375 140

Loss: %f -1461.7974712848663
torch.Size([64, 14, 269])
torch.Size([14, 269])
torch.Size([16, 1, 269, 896])
torch.Size([1, 269, 1782])
torch.Size([269, 1782])
torch.Size([1782])
[135.6875 135.6875 135.6875 135.6875 135.6875 135.6875 135.6875 135.6875
 135.6875 135.6875 135.6875 135.6875 135.6875 135.6875 135.6875 135.6875]
269
(269, 269, 269, 269, 269, 269, 269, 269, 269, 269, 269, 269, 269, 269, 269, 269)
torch.Size([269, 16, 1782])
Loss: %f -1466.7815334796906
torch.Size([64, 14, 210])
torch.Size([14, 210])
torch.Size([16, 1, 210, 896])
torch.Size([1, 210, 1782])
torch.Size([210, 1782])
torch.Size([1782])
[106.0625 106.0625 106.0625 106.0625 106.0625 106.0625 106.0625 106.0625
 106.0625 106.0625 106.0625 106.0625 106.0625 106.0625 106.0625 106.0625]
210
(210, 210, 210, 210, 210, 210, 210, 210, 210, 210, 210, 210, 210, 210, 210, 210)
torch.Size([210, 16, 1782])
Loss: %f -1471.3941805362701
torch.Size([64, 14, 267])
torch.Size([14, 267])
torch.Size([16, 1, 267, 896])
torch.Size([1, 267,

Loss: %f -1559.1839153766632
torch.Size([64, 14, 248])
torch.Size([14, 248])
torch.Size([16, 1, 248, 896])
torch.Size([1, 248, 1782])
torch.Size([248, 1782])
torch.Size([1782])
[125.125 125.125 125.125 125.125 125.125 125.125 125.125 125.125 125.125
 125.125 125.125 125.125 125.125 125.125 125.125 125.125]
248
(248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248)
torch.Size([248, 16, 1782])
Loss: %f -1564.0926020145416
torch.Size([64, 14, 244])
torch.Size([14, 244])
torch.Size([16, 1, 244, 896])
torch.Size([1, 244, 1782])
torch.Size([244, 1782])
torch.Size([1782])
[123.1875 123.1875 123.1875 123.1875 123.1875 123.1875 123.1875 123.1875
 123.1875 123.1875 123.1875 123.1875 123.1875 123.1875 123.1875 123.1875]
244
(244, 244, 244, 244, 244, 244, 244, 244, 244, 244, 244, 244, 244, 244, 244, 244)
torch.Size([244, 16, 1782])
Loss: %f -1569.0960953235626
torch.Size([64, 14, 248])
torch.Size([14, 248])
torch.Size([16, 1, 248, 896])
torch.Size([1, 248, 1782])
torch.Si

Loss: %f -1655.77254986763
torch.Size([64, 14, 252])
torch.Size([14, 252])
torch.Size([16, 1, 252, 896])
torch.Size([1, 252, 1782])
torch.Size([252, 1782])
torch.Size([1782])
[127.0625 127.0625 127.0625 127.0625 127.0625 127.0625 127.0625 127.0625
 127.0625 127.0625 127.0625 127.0625 127.0625 127.0625 127.0625 127.0625]
252
(252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252)
torch.Size([252, 16, 1782])
Loss: %f -1660.7312195301056
torch.Size([64, 14, 209])
torch.Size([14, 209])
torch.Size([16, 1, 209, 896])
torch.Size([1, 209, 1782])
torch.Size([209, 1782])
torch.Size([1782])
[105.8125 105.8125 105.8125 105.8125 105.8125 105.8125 105.8125 105.8125
 105.8125 105.8125 105.8125 105.8125 105.8125 105.8125 105.8125 105.8125]
209
(209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209)
torch.Size([209, 16, 1782])
Loss: %f -1665.4618999958038
torch.Size([64, 14, 246])
torch.Size([14, 246])
torch.Size([16, 1, 246, 896])
torch.Size([1, 246, 1

torch.Size([64, 14, 214])
torch.Size([14, 214])
torch.Size([16, 1, 214, 896])
torch.Size([1, 214, 1782])
torch.Size([214, 1782])
torch.Size([1782])
[108.1875 108.1875 108.1875 108.1875 108.1875 108.1875 108.1875 108.1875
 108.1875 108.1875 108.1875 108.1875 108.1875 108.1875 108.1875 108.1875]
214
(214, 214, 214, 214, 214, 214, 214, 214, 214, 214, 214, 214, 214, 214, 214, 214)
torch.Size([214, 16, 1782])
Loss: %f -1750.1649825572968
torch.Size([64, 14, 248])
torch.Size([14, 248])
torch.Size([16, 1, 248, 896])
torch.Size([1, 248, 1782])
torch.Size([248, 1782])
torch.Size([1782])
[125.125 125.125 125.125 125.125 125.125 125.125 125.125 125.125 125.125
 125.125 125.125 125.125 125.125 125.125 125.125 125.125]
248
(248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248)
torch.Size([248, 16, 1782])
Loss: %f -1755.0974547863007
torch.Size([64, 14, 252])
torch.Size([14, 252])
torch.Size([16, 1, 252, 896])
torch.Size([1, 252, 1782])
torch.Size([252, 1782])
torch.Size([1

Loss: %f -1840.997614622116
torch.Size([64, 14, 289])
torch.Size([14, 289])
torch.Size([16, 1, 289, 896])
torch.Size([1, 289, 1782])
torch.Size([289, 1782])
torch.Size([1782])
[145.5625 145.5625 145.5625 145.5625 145.5625 145.5625 145.5625 145.5625
 145.5625 145.5625 145.5625 145.5625 145.5625 145.5625 145.5625 145.5625]
289
(289, 289, 289, 289, 289, 289, 289, 289, 289, 289, 289, 289, 289, 289, 289, 289)
torch.Size([289, 16, 1782])
Loss: %f -1846.261744260788
torch.Size([64, 14, 248])
torch.Size([14, 248])
torch.Size([16, 1, 248, 896])
torch.Size([1, 248, 1782])
torch.Size([248, 1782])
torch.Size([1782])
[125.125 125.125 125.125 125.125 125.125 125.125 125.125 125.125 125.125
 125.125 125.125 125.125 125.125 125.125 125.125 125.125]
248
(248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248)
torch.Size([248, 16, 1782])
Loss: %f -1851.3330829143524
torch.Size([64, 14, 246])
torch.Size([14, 246])
torch.Size([16, 1, 246, 896])
torch.Size([1, 246, 1782])
torch.Size

Loss: %f -1935.9260976314545
torch.Size([64, 14, 216])
torch.Size([14, 216])
torch.Size([16, 1, 216, 896])
torch.Size([1, 216, 1782])
torch.Size([216, 1782])
torch.Size([1782])
[108.9375 108.9375 108.9375 108.9375 108.9375 108.9375 108.9375 108.9375
 108.9375 108.9375 108.9375 108.9375 108.9375 108.9375 108.9375 108.9375]
216
(216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216)
torch.Size([216, 16, 1782])
Loss: %f -1940.566914319992
torch.Size([64, 14, 217])
torch.Size([14, 217])
torch.Size([16, 1, 217, 896])
torch.Size([1, 217, 1782])
torch.Size([217, 1782])
torch.Size([1782])
[109.6875 109.6875 109.6875 109.6875 109.6875 109.6875 109.6875 109.6875
 109.6875 109.6875 109.6875 109.6875 109.6875 109.6875 109.6875 109.6875]
217
(217, 217, 217, 217, 217, 217, 217, 217, 217, 217, 217, 217, 217, 217, 217, 217)
torch.Size([217, 16, 1782])
Loss: %f -1945.1947767734528
torch.Size([64, 14, 248])
torch.Size([14, 248])
torch.Size([16, 1, 248, 896])
torch.Size([1, 248, 

Loss: %f -2031.3547270298004
torch.Size([64, 14, 248])
torch.Size([14, 248])
torch.Size([16, 1, 248, 896])
torch.Size([1, 248, 1782])
torch.Size([248, 1782])
torch.Size([1782])
[125.125 125.125 125.125 125.125 125.125 125.125 125.125 125.125 125.125
 125.125 125.125 125.125 125.125 125.125 125.125 125.125]
248
(248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248)
torch.Size([248, 16, 1782])
Loss: %f -2036.0075285434723
torch.Size([64, 14, 248])
torch.Size([14, 248])
torch.Size([16, 1, 248, 896])
torch.Size([1, 248, 1782])
torch.Size([248, 1782])
torch.Size([1782])
[125.125 125.125 125.125 125.125 125.125 125.125 125.125 125.125 125.125
 125.125 125.125 125.125 125.125 125.125 125.125 125.125]
248
(248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248)
torch.Size([248, 16, 1782])
Loss: %f -2040.8574440479279
torch.Size([64, 14, 206])
torch.Size([14, 206])
torch.Size([16, 1, 206, 896])
torch.Size([1, 206, 1782])
torch.Size([206, 1782])


Loss: %f -2126.058282136917
torch.Size([64, 14, 204])
torch.Size([14, 204])
torch.Size([16, 1, 204, 896])
torch.Size([1, 204, 1782])
torch.Size([204, 1782])
torch.Size([1782])
[103.3125 103.3125 103.3125 103.3125 103.3125 103.3125 103.3125 103.3125
 103.3125 103.3125 103.3125 103.3125 103.3125 103.3125 103.3125 103.3125]
204
(204, 204, 204, 204, 204, 204, 204, 204, 204, 204, 204, 204, 204, 204, 204, 204)
torch.Size([204, 16, 1782])
Loss: %f -2130.5711419582367
torch.Size([64, 14, 179])
torch.Size([14, 179])
torch.Size([16, 1, 179, 896])
torch.Size([1, 179, 1782])
torch.Size([179, 1782])
torch.Size([1782])
[90.6875 90.6875 90.6875 90.6875 90.6875 90.6875 90.6875 90.6875 90.6875
 90.6875 90.6875 90.6875 90.6875 90.6875 90.6875 90.6875]
179
(179, 179, 179, 179, 179, 179, 179, 179, 179, 179, 179, 179, 179, 179, 179, 179)
torch.Size([179, 16, 1782])
Loss: %f -2134.956085920334
torch.Size([64, 14, 173])
torch.Size([14, 173])
torch.Size([16, 1, 173, 896])
torch.Size([1, 173, 1782])
torch.Size

Loss: %f -2220.0674850940704
torch.Size([64, 14, 170])
torch.Size([14, 170])
torch.Size([16, 1, 170, 896])
torch.Size([1, 170, 1782])
torch.Size([170, 1782])
torch.Size([1782])
[85.9375 85.9375 85.9375 85.9375 85.9375 85.9375 85.9375 85.9375 85.9375
 85.9375 85.9375 85.9375 85.9375 85.9375 85.9375 85.9375]
170
(170, 170, 170, 170, 170, 170, 170, 170, 170, 170, 170, 170, 170, 170, 170, 170)
torch.Size([170, 16, 1782])
Loss: %f -2224.249628305435
torch.Size([64, 14, 169])
torch.Size([14, 169])
torch.Size([16, 1, 169, 896])
torch.Size([1, 169, 1782])
torch.Size([169, 1782])
torch.Size([1782])
[85.375 85.375 85.375 85.375 85.375 85.375 85.375 85.375 85.375 85.375
 85.375 85.375 85.375 85.375 85.375 85.375]
169
(169, 169, 169, 169, 169, 169, 169, 169, 169, 169, 169, 169, 169, 169, 169, 169)
torch.Size([169, 16, 1782])
Loss: %f -2228.5389173030853
torch.Size([64, 14, 248])
torch.Size([14, 248])
torch.Size([16, 1, 248, 896])
torch.Size([1, 248, 1782])
torch.Size([248, 1782])
torch.Size([1782]

Loss: %f -2314.881445169449
torch.Size([64, 14, 289])
torch.Size([14, 289])
torch.Size([16, 1, 289, 896])
torch.Size([1, 289, 1782])
torch.Size([289, 1782])
torch.Size([1782])
[145.5625 145.5625 145.5625 145.5625 145.5625 145.5625 145.5625 145.5625
 145.5625 145.5625 145.5625 145.5625 145.5625 145.5625 145.5625 145.5625]
289
(289, 289, 289, 289, 289, 289, 289, 289, 289, 289, 289, 289, 289, 289, 289, 289)
torch.Size([289, 16, 1782])
Loss: %f -2320.0994865894318
torch.Size([64, 14, 214])
torch.Size([14, 214])
torch.Size([16, 1, 214, 896])
torch.Size([1, 214, 1782])
torch.Size([214, 1782])
torch.Size([1782])
[108.1875 108.1875 108.1875 108.1875 108.1875 108.1875 108.1875 108.1875
 108.1875 108.1875 108.1875 108.1875 108.1875 108.1875 108.1875 108.1875]
214
(214, 214, 214, 214, 214, 214, 214, 214, 214, 214, 214, 214, 214, 214, 214, 214)
torch.Size([214, 16, 1782])
Loss: %f -2324.7258660793304
torch.Size([64, 14, 252])
torch.Size([14, 252])
torch.Size([16, 1, 252, 896])
torch.Size([1, 252, 

Loss: %f -2410.892429113388
torch.Size([64, 14, 217])
torch.Size([14, 217])
torch.Size([16, 1, 217, 896])
torch.Size([1, 217, 1782])
torch.Size([217, 1782])
torch.Size([1782])
[109.6875 109.6875 109.6875 109.6875 109.6875 109.6875 109.6875 109.6875
 109.6875 109.6875 109.6875 109.6875 109.6875 109.6875 109.6875 109.6875]
217
(217, 217, 217, 217, 217, 217, 217, 217, 217, 217, 217, 217, 217, 217, 217, 217)
torch.Size([217, 16, 1782])
Loss: %f -2415.4342997074127
torch.Size([64, 14, 286])
torch.Size([14, 286])
torch.Size([16, 1, 286, 896])
torch.Size([1, 286, 1782])
torch.Size([286, 1782])
torch.Size([1782])
[144.25 144.25 144.25 144.25 144.25 144.25 144.25 144.25 144.25 144.25
 144.25 144.25 144.25 144.25 144.25 144.25]
286
(286, 286, 286, 286, 286, 286, 286, 286, 286, 286, 286, 286, 286, 286, 286, 286)
torch.Size([286, 16, 1782])
Loss: %f -2420.4736244678497
torch.Size([64, 14, 286])
torch.Size([14, 286])
torch.Size([16, 1, 286, 896])
torch.Size([1, 286, 1782])
torch.Size([286, 1782])
t

Loss: %f -2506.7102444171906
torch.Size([64, 14, 217])
torch.Size([14, 217])
torch.Size([16, 1, 217, 896])
torch.Size([1, 217, 1782])
torch.Size([217, 1782])
torch.Size([1782])
[109.6875 109.6875 109.6875 109.6875 109.6875 109.6875 109.6875 109.6875
 109.6875 109.6875 109.6875 109.6875 109.6875 109.6875 109.6875 109.6875]
217
(217, 217, 217, 217, 217, 217, 217, 217, 217, 217, 217, 217, 217, 217, 217, 217)
torch.Size([217, 16, 1782])
Loss: %f -2511.3310244083405
torch.Size([64, 14, 272])
torch.Size([14, 272])
torch.Size([16, 1, 272, 896])
torch.Size([1, 272, 1782])
torch.Size([272, 1782])
torch.Size([1782])
[137.0625 137.0625 137.0625 137.0625 137.0625 137.0625 137.0625 137.0625
 137.0625 137.0625 137.0625 137.0625 137.0625 137.0625 137.0625 137.0625]
272
(272, 272, 272, 272, 272, 272, 272, 272, 272, 272, 272, 272, 272, 272, 272, 272)
torch.Size([272, 16, 1782])
Loss: %f -2516.3810560703278
torch.Size([64, 14, 252])
torch.Size([14, 252])
torch.Size([16, 1, 252, 896])
torch.Size([1, 252,

Loss: %f -2604.813004732132
torch.Size([64, 14, 267])
torch.Size([14, 267])
torch.Size([16, 1, 267, 896])
torch.Size([1, 267, 1782])
torch.Size([267, 1782])
torch.Size([1782])
[134.5625 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625
 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625]
267
(267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267)
torch.Size([267, 16, 1782])
Loss: %f -2609.8552849292755
torch.Size([64, 14, 248])
torch.Size([14, 248])
torch.Size([16, 1, 248, 896])
torch.Size([1, 248, 1782])
torch.Size([248, 1782])
torch.Size([1782])
[125.125 125.125 125.125 125.125 125.125 125.125 125.125 125.125 125.125
 125.125 125.125 125.125 125.125 125.125 125.125 125.125]
248
(248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248)
torch.Size([248, 16, 1782])
Loss: %f -2614.586662530899
torch.Size([64, 14, 184])
torch.Size([14, 184])
torch.Size([16, 1, 184, 896])
torch.Size([1, 184, 1782])
torch.Size

Loss: %f -2699.1255605220795
torch.Size([64, 14, 252])
torch.Size([14, 252])
torch.Size([16, 1, 252, 896])
torch.Size([1, 252, 1782])
torch.Size([252, 1782])
torch.Size([1782])
[127.0625 127.0625 127.0625 127.0625 127.0625 127.0625 127.0625 127.0625
 127.0625 127.0625 127.0625 127.0625 127.0625 127.0625 127.0625 127.0625]
252
(252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252)
torch.Size([252, 16, 1782])
Loss: %f -2703.9916899204254
torch.Size([64, 14, 228])
torch.Size([14, 228])
torch.Size([16, 1, 228, 896])
torch.Size([1, 228, 1782])
torch.Size([228, 1782])
torch.Size([1782])
[115.1875 115.1875 115.1875 115.1875 115.1875 115.1875 115.1875 115.1875
 115.1875 115.1875 115.1875 115.1875 115.1875 115.1875 115.1875 115.1875]
228
(228, 228, 228, 228, 228, 228, 228, 228, 228, 228, 228, 228, 228, 228, 228, 228)
torch.Size([228, 16, 1782])
Loss: %f -2708.5738599300385
torch.Size([64, 14, 204])
torch.Size([14, 204])
torch.Size([16, 1, 204, 896])
torch.Size([1, 204,

Loss: %f -2792.326601266861
torch.Size([64, 14, 279])
torch.Size([14, 279])
torch.Size([16, 1, 279, 896])
torch.Size([1, 279, 1782])
torch.Size([279, 1782])
torch.Size([1782])
[140.4375 140.4375 140.4375 140.4375 140.4375 140.4375 140.4375 140.4375
 140.4375 140.4375 140.4375 140.4375 140.4375 140.4375 140.4375 140.4375]
279
(279, 279, 279, 279, 279, 279, 279, 279, 279, 279, 279, 279, 279, 279, 279, 279)
torch.Size([279, 16, 1782])
Loss: %f -2797.5382940769196
torch.Size([64, 14, 232])
torch.Size([14, 232])
torch.Size([16, 1, 232, 896])
torch.Size([1, 232, 1782])
torch.Size([232, 1782])
torch.Size([1782])
[116.875 116.875 116.875 116.875 116.875 116.875 116.875 116.875 116.875
 116.875 116.875 116.875 116.875 116.875 116.875 116.875]
232
(232, 232, 232, 232, 232, 232, 232, 232, 232, 232, 232, 232, 232, 232, 232, 232)
torch.Size([232, 16, 1782])
Loss: %f -2802.3585011959076
torch.Size([64, 14, 216])
torch.Size([14, 216])
torch.Size([16, 1, 216, 896])
torch.Size([1, 216, 1782])
torch.Siz

Loss: %f -2888.896350622177
torch.Size([64, 14, 248])
torch.Size([14, 248])
torch.Size([16, 1, 248, 896])
torch.Size([1, 248, 1782])
torch.Size([248, 1782])
torch.Size([1782])
[125.125 125.125 125.125 125.125 125.125 125.125 125.125 125.125 125.125
 125.125 125.125 125.125 125.125 125.125 125.125 125.125]
248
(248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248)
torch.Size([248, 16, 1782])
Loss: %f -2893.7951896190643
torch.Size([64, 14, 244])
torch.Size([14, 244])
torch.Size([16, 1, 244, 896])
torch.Size([1, 244, 1782])
torch.Size([244, 1782])
torch.Size([1782])
[123.1875 123.1875 123.1875 123.1875 123.1875 123.1875 123.1875 123.1875
 123.1875 123.1875 123.1875 123.1875 123.1875 123.1875 123.1875 123.1875]
244
(244, 244, 244, 244, 244, 244, 244, 244, 244, 244, 244, 244, 244, 244, 244, 244)
torch.Size([244, 16, 1782])
Loss: %f -2898.6233274936676
torch.Size([64, 14, 199])
torch.Size([14, 199])
torch.Size([16, 1, 199, 896])
torch.Size([1, 199, 1782])
torch.Siz

Loss: %f -2985.3656475543976
torch.Size([64, 14, 248])
torch.Size([14, 248])
torch.Size([16, 1, 248, 896])
torch.Size([1, 248, 1782])
torch.Size([248, 1782])
torch.Size([1782])
[125.125 125.125 125.125 125.125 125.125 125.125 125.125 125.125 125.125
 125.125 125.125 125.125 125.125 125.125 125.125 125.125]
248
(248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248)
torch.Size([248, 16, 1782])
Loss: %f -2990.2418162822723
torch.Size([64, 14, 312])
torch.Size([14, 312])
torch.Size([16, 1, 312, 896])
torch.Size([1, 312, 1782])
torch.Size([312, 1782])
torch.Size([1782])
[157. 157. 157. 157. 157. 157. 157. 157. 157. 157. 157. 157. 157. 157.
 157. 157.]
312
(312, 312, 312, 312, 312, 312, 312, 312, 312, 312, 312, 312, 312, 312, 312, 312)
torch.Size([312, 16, 1782])
Loss: %f -2995.470979452133
torch.Size([64, 14, 246])
torch.Size([14, 246])
torch.Size([16, 1, 246, 896])
torch.Size([1, 246, 1782])
torch.Size([246, 1782])
torch.Size([1782])
[124.125 124.125 124.125 124.1

Loss: %f -3081.5087325572968
torch.Size([64, 14, 229])
torch.Size([14, 229])
torch.Size([16, 1, 229, 896])
torch.Size([1, 229, 1782])
torch.Size([229, 1782])
torch.Size([1782])
[115.6875 115.6875 115.6875 115.6875 115.6875 115.6875 115.6875 115.6875
 115.6875 115.6875 115.6875 115.6875 115.6875 115.6875 115.6875 115.6875]
229
(229, 229, 229, 229, 229, 229, 229, 229, 229, 229, 229, 229, 229, 229, 229, 229)
torch.Size([229, 16, 1782])
Loss: %f -3086.142194032669
torch.Size([64, 14, 254])
torch.Size([14, 254])
torch.Size([16, 1, 254, 896])
torch.Size([1, 254, 1782])
torch.Size([254, 1782])
torch.Size([1782])
[128.125 128.125 128.125 128.125 128.125 128.125 128.125 128.125 128.125
 128.125 128.125 128.125 128.125 128.125 128.125 128.125]
254
(254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254)
torch.Size([254, 16, 1782])
Loss: %f -3091.086818933487
torch.Size([64, 14, 265])
torch.Size([14, 265])
torch.Size([16, 1, 265, 896])
torch.Size([1, 265, 1782])
torch.Size

Loss: %f -3178.9981820583344
torch.Size([64, 14, 230])
torch.Size([14, 230])
torch.Size([16, 1, 230, 896])
torch.Size([1, 230, 1782])
torch.Size([230, 1782])
torch.Size([1782])
[116. 116. 116. 116. 116. 116. 116. 116. 116. 116. 116. 116. 116. 116.
 116. 116.]
230
(230, 230, 230, 230, 230, 230, 230, 230, 230, 230, 230, 230, 230, 230, 230, 230)
torch.Size([230, 16, 1782])
Loss: %f -3183.6208341121674
torch.Size([64, 14, 248])
torch.Size([14, 248])
torch.Size([16, 1, 248, 896])
torch.Size([1, 248, 1782])
torch.Size([248, 1782])
torch.Size([1782])
[125.125 125.125 125.125 125.125 125.125 125.125 125.125 125.125 125.125
 125.125 125.125 125.125 125.125 125.125 125.125 125.125]
248
(248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248)
torch.Size([248, 16, 1782])
Loss: %f -3188.481274366379
torch.Size([64, 14, 252])
torch.Size([14, 252])
torch.Size([16, 1, 252, 896])
torch.Size([1, 252, 1782])
torch.Size([252, 1782])
torch.Size([1782])
[127.0625 127.0625 127.0625 12

Loss: %f -3272.2253925800323
torch.Size([64, 14, 315])
torch.Size([14, 315])
torch.Size([16, 1, 315, 896])
torch.Size([1, 315, 1782])
torch.Size([315, 1782])
torch.Size([1782])
[158.5625 158.5625 158.5625 158.5625 158.5625 158.5625 158.5625 158.5625
 158.5625 158.5625 158.5625 158.5625 158.5625 158.5625 158.5625 158.5625]
315
(315, 315, 315, 315, 315, 315, 315, 315, 315, 315, 315, 315, 315, 315, 315, 315)
torch.Size([315, 16, 1782])
Loss: %f -3277.7059795856476
torch.Size([64, 14, 286])
torch.Size([14, 286])
torch.Size([16, 1, 286, 896])
torch.Size([1, 286, 1782])
torch.Size([286, 1782])
torch.Size([1782])
[144.25 144.25 144.25 144.25 144.25 144.25 144.25 144.25 144.25 144.25
 144.25 144.25 144.25 144.25 144.25 144.25]
286
(286, 286, 286, 286, 286, 286, 286, 286, 286, 286, 286, 286, 286, 286, 286, 286)
torch.Size([286, 16, 1782])
Loss: %f -3282.862378835678
torch.Size([64, 14, 233])
torch.Size([14, 233])
torch.Size([16, 1, 233, 896])
torch.Size([1, 233, 1782])
torch.Size([233, 1782])
t

Loss: %f -3368.3217499256134
torch.Size([64, 14, 248])
torch.Size([14, 248])
torch.Size([16, 1, 248, 896])
torch.Size([1, 248, 1782])
torch.Size([248, 1782])
torch.Size([1782])
[125.125 125.125 125.125 125.125 125.125 125.125 125.125 125.125 125.125
 125.125 125.125 125.125 125.125 125.125 125.125 125.125]
248
(248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248)
torch.Size([248, 16, 1782])
Loss: %f -3373.2783086299896
torch.Size([64, 14, 186])
torch.Size([14, 186])
torch.Size([16, 1, 186, 896])
torch.Size([1, 186, 1782])
torch.Size([186, 1782])
torch.Size([1782])
[93.9375 93.9375 93.9375 93.9375 93.9375 93.9375 93.9375 93.9375 93.9375
 93.9375 93.9375 93.9375 93.9375 93.9375 93.9375 93.9375]
186
(186, 186, 186, 186, 186, 186, 186, 186, 186, 186, 186, 186, 186, 186, 186, 186)
torch.Size([186, 16, 1782])
Loss: %f -3377.7448332309723
torch.Size([64, 14, 232])
torch.Size([14, 232])
torch.Size([16, 1, 232, 896])
torch.Size([1, 232, 1782])
torch.Size([232, 1782])


Loss: %f -3463.336759328842
torch.Size([64, 14, 214])
torch.Size([14, 214])
torch.Size([16, 1, 214, 896])
torch.Size([1, 214, 1782])
torch.Size([214, 1782])
torch.Size([1782])
[108.1875 108.1875 108.1875 108.1875 108.1875 108.1875 108.1875 108.1875
 108.1875 108.1875 108.1875 108.1875 108.1875 108.1875 108.1875 108.1875]
214
(214, 214, 214, 214, 214, 214, 214, 214, 214, 214, 214, 214, 214, 214, 214, 214)
torch.Size([214, 16, 1782])
Loss: %f -3467.9580295085907
torch.Size([64, 14, 248])
torch.Size([14, 248])
torch.Size([16, 1, 248, 896])
torch.Size([1, 248, 1782])
torch.Size([248, 1782])
torch.Size([1782])
[125.125 125.125 125.125 125.125 125.125 125.125 125.125 125.125 125.125
 125.125 125.125 125.125 125.125 125.125 125.125 125.125]
248
(248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248)
torch.Size([248, 16, 1782])
Loss: %f -3472.9064242839813
torch.Size([64, 14, 237])
torch.Size([14, 237])
torch.Size([16, 1, 237, 896])
torch.Size([1, 237, 1782])
torch.Siz

Loss: %f -3558.1651895046234
torch.Size([64, 14, 248])
torch.Size([14, 248])
torch.Size([16, 1, 248, 896])
torch.Size([1, 248, 1782])
torch.Size([248, 1782])
torch.Size([1782])
[125.125 125.125 125.125 125.125 125.125 125.125 125.125 125.125 125.125
 125.125 125.125 125.125 125.125 125.125 125.125 125.125]
248
(248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248)
torch.Size([248, 16, 1782])
Loss: %f -3563.2329547405243
torch.Size([64, 14, 346])
torch.Size([14, 346])
torch.Size([16, 1, 346, 896])
torch.Size([1, 346, 1782])
torch.Size([346, 1782])
torch.Size([1782])
[174.0625 174.0625 174.0625 174.0625 174.0625 174.0625 174.0625 174.0625
 174.0625 174.0625 174.0625 174.0625 174.0625 174.0625 174.0625 174.0625]
346
(346, 346, 346, 346, 346, 346, 346, 346, 346, 346, 346, 346, 346, 346, 346, 346)
torch.Size([346, 16, 1782])
Loss: %f -3568.6882798671722
torch.Size([64, 14, 246])
torch.Size([14, 246])
torch.Size([16, 1, 246, 896])
torch.Size([1, 246, 1782])
torch.Si

Loss: %f -3654.9353806972504
torch.Size([64, 14, 156])
torch.Size([14, 156])
torch.Size([16, 1, 156, 896])
torch.Size([1, 156, 1782])
torch.Size([156, 1782])
torch.Size([1782])
[79.0625 79.0625 79.0625 79.0625 79.0625 79.0625 79.0625 79.0625 79.0625
 79.0625 79.0625 79.0625 79.0625 79.0625 79.0625 79.0625]
156
(156, 156, 156, 156, 156, 156, 156, 156, 156, 156, 156, 156, 156, 156, 156, 156)
torch.Size([156, 16, 1782])
Loss: %f -3658.989548444748
torch.Size([64, 14, 198])
torch.Size([14, 198])
torch.Size([16, 1, 198, 896])
torch.Size([1, 198, 1782])
torch.Size([198, 1782])
torch.Size([1782])
[100.1875 100.1875 100.1875 100.1875 100.1875 100.1875 100.1875 100.1875
 100.1875 100.1875 100.1875 100.1875 100.1875 100.1875 100.1875 100.1875]
198
(198, 198, 198, 198, 198, 198, 198, 198, 198, 198, 198, 198, 198, 198, 198, 198)
torch.Size([198, 16, 1782])
Loss: %f -3663.391349554062
torch.Size([64, 14, 252])
torch.Size([14, 252])
torch.Size([16, 1, 252, 896])
torch.Size([1, 252, 1782])
torch.Size

Loss: %f -3749.052880525589
torch.Size([64, 14, 207])
torch.Size([14, 207])
torch.Size([16, 1, 207, 896])
torch.Size([1, 207, 1782])
torch.Size([207, 1782])
torch.Size([1782])
[104.6875 104.6875 104.6875 104.6875 104.6875 104.6875 104.6875 104.6875
 104.6875 104.6875 104.6875 104.6875 104.6875 104.6875 104.6875 104.6875]
207
(207, 207, 207, 207, 207, 207, 207, 207, 207, 207, 207, 207, 207, 207, 207, 207)
torch.Size([207, 16, 1782])
Loss: %f -3753.7290766239166
torch.Size([64, 14, 187])
torch.Size([14, 187])
torch.Size([16, 1, 187, 896])
torch.Size([1, 187, 1782])
torch.Size([187, 1782])
torch.Size([1782])
[94.5 94.5 94.5 94.5 94.5 94.5 94.5 94.5 94.5 94.5 94.5 94.5 94.5 94.5
 94.5 94.5]
187
(187, 187, 187, 187, 187, 187, 187, 187, 187, 187, 187, 187, 187, 187, 187, 187)
torch.Size([187, 16, 1782])
Loss: %f -3758.0537555217743
torch.Size([64, 14, 248])
torch.Size([14, 248])
torch.Size([16, 1, 248, 896])
torch.Size([1, 248, 1782])
torch.Size([248, 1782])
torch.Size([1782])
[125.125 125.1

Loss: %f -3844.9360682964325
torch.Size([64, 14, 248])
torch.Size([14, 248])
torch.Size([16, 1, 248, 896])
torch.Size([1, 248, 1782])
torch.Size([248, 1782])
torch.Size([1782])
[125.125 125.125 125.125 125.125 125.125 125.125 125.125 125.125 125.125
 125.125 125.125 125.125 125.125 125.125 125.125 125.125]
248
(248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248)
torch.Size([248, 16, 1782])
Loss: %f -3849.622653722763
torch.Size([64, 14, 250])
torch.Size([14, 250])
torch.Size([16, 1, 250, 896])
torch.Size([1, 250, 1782])
torch.Size([250, 1782])
torch.Size([1782])
[126.0625 126.0625 126.0625 126.0625 126.0625 126.0625 126.0625 126.0625
 126.0625 126.0625 126.0625 126.0625 126.0625 126.0625 126.0625 126.0625]
250
(250, 250, 250, 250, 250, 250, 250, 250, 250, 250, 250, 250, 250, 250, 250, 250)
torch.Size([250, 16, 1782])
Loss: %f -3854.421322584152
torch.Size([64, 14, 232])
torch.Size([14, 232])
torch.Size([16, 1, 232, 896])
torch.Size([1, 232, 1782])
torch.Size

Loss: %f -3940.9987227916718
torch.Size([64, 14, 248])
torch.Size([14, 248])
torch.Size([16, 1, 248, 896])
torch.Size([1, 248, 1782])
torch.Size([248, 1782])
torch.Size([1782])
[125.125 125.125 125.125 125.125 125.125 125.125 125.125 125.125 125.125
 125.125 125.125 125.125 125.125 125.125 125.125 125.125]
248
(248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248, 248)
torch.Size([248, 16, 1782])
Loss: %f -3945.9420535564423
torch.Size([64, 14, 267])
torch.Size([14, 267])
torch.Size([16, 1, 267, 896])
torch.Size([1, 267, 1782])
torch.Size([267, 1782])
torch.Size([1782])
[134.5625 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625
 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625 134.5625]
267
(267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267, 267)
torch.Size([267, 16, 1782])
Loss: %f -3951.0159108638763
torch.Size([64, 14, 248])
torch.Size([14, 248])
torch.Size([16, 1, 248, 896])
torch.Size([1, 248, 1782])
torch.Si

Loss: %f -4037.722229242325
torch.Size([64, 14, 224])
torch.Size([14, 224])
torch.Size([16, 1, 224, 896])
torch.Size([1, 224, 1782])
torch.Size([224, 1782])
torch.Size([1782])
[113.1875 113.1875 113.1875 113.1875 113.1875 113.1875 113.1875 113.1875
 113.1875 113.1875 113.1875 113.1875 113.1875 113.1875 113.1875 113.1875]
224
(224, 224, 224, 224, 224, 224, 224, 224, 224, 224, 224, 224, 224, 224, 224, 224)
torch.Size([224, 16, 1782])
Loss: %f -4042.5210058689117
torch.Size([64, 14, 252])
torch.Size([14, 252])
torch.Size([16, 1, 252, 896])
torch.Size([1, 252, 1782])
torch.Size([252, 1782])
torch.Size([1782])
[127.0625 127.0625 127.0625 127.0625 127.0625 127.0625 127.0625 127.0625
 127.0625 127.0625 127.0625 127.0625 127.0625 127.0625 127.0625 127.0625]
252
(252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252)
torch.Size([252, 16, 1782])
Loss: %f -4047.3196079730988
torch.Size([64, 14, 202])
torch.Size([14, 202])
torch.Size([16, 1, 202, 896])
torch.Size([1, 202, 

Loss: %f -4134.556245088577
torch.Size([64, 14, 199])
torch.Size([14, 199])
torch.Size([16, 1, 199, 896])
torch.Size([1, 199, 1782])
torch.Size([199, 1782])
torch.Size([1782])
[100.6875 100.6875 100.6875 100.6875 100.6875 100.6875 100.6875 100.6875
 100.6875 100.6875 100.6875 100.6875 100.6875 100.6875 100.6875 100.6875]
199
(199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199)
torch.Size([199, 16, 1782])
Loss: %f -4138.983684301376
torch.Size([64, 14, 230])
torch.Size([14, 230])
torch.Size([16, 1, 230, 896])
torch.Size([1, 230, 1782])
torch.Size([230, 1782])
torch.Size([1782])
[116. 116. 116. 116. 116. 116. 116. 116. 116. 116. 116. 116. 116. 116.
 116. 116.]
230
(230, 230, 230, 230, 230, 230, 230, 230, 230, 230, 230, 230, 230, 230, 230, 230)
torch.Size([230, 16, 1782])
Loss: %f -4143.6697561740875
torch.Size([64, 14, 217])
torch.Size([14, 217])
torch.Size([16, 1, 217, 896])
torch.Size([1, 217, 1782])
torch.Size([217, 1782])
torch.Size([1782])
[109.6875 109.6

Loss: %f -4231.725568532944
torch.Size([64, 14, 331])
torch.Size([14, 331])
torch.Size([16, 1, 331, 896])
torch.Size([1, 331, 1782])
torch.Size([331, 1782])
torch.Size([1782])
[166.8125 166.8125 166.8125 166.8125 166.8125 166.8125 166.8125 166.8125
 166.8125 166.8125 166.8125 166.8125 166.8125 166.8125 166.8125 166.8125]
331
(331, 331, 331, 331, 331, 331, 331, 331, 331, 331, 331, 331, 331, 331, 331, 331)
torch.Size([331, 16, 1782])
Loss: %f -4237.085881948471
torch.Size([64, 14, 346])
torch.Size([14, 346])
torch.Size([16, 1, 346, 896])
torch.Size([1, 346, 1782])
torch.Size([346, 1782])
torch.Size([1782])
[174.0625 174.0625 174.0625 174.0625 174.0625 174.0625 174.0625 174.0625
 174.0625 174.0625 174.0625 174.0625 174.0625 174.0625 174.0625 174.0625]
346
(346, 346, 346, 346, 346, 346, 346, 346, 346, 346, 346, 346, 346, 346, 346, 346)
torch.Size([346, 16, 1782])
Loss: %f -4242.59024310112
torch.Size([64, 14, 267])
torch.Size([14, 267])
torch.Size([16, 1, 267, 896])
torch.Size([1, 267, 178

Loss: %f -4328.50306725502
torch.Size([64, 14, 206])
torch.Size([14, 206])
torch.Size([16, 1, 206, 896])
torch.Size([1, 206, 1782])
torch.Size([206, 1782])
torch.Size([1782])
[104. 104. 104. 104. 104. 104. 104. 104. 104. 104. 104. 104. 104. 104.
 104. 104.]
206
(206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206)
torch.Size([206, 16, 1782])
Loss: %f -4332.994424581528
torch.Size([64, 14, 207])
torch.Size([14, 207])
torch.Size([16, 1, 207, 896])
torch.Size([1, 207, 1782])
torch.Size([207, 1782])
torch.Size([1782])
[104.625 104.625 104.625 104.625 104.625 104.625 104.625 104.625 104.625
 104.625 104.625 104.625 104.625 104.625 104.625 104.625]
207
(207, 207, 207, 207, 207, 207, 207, 207, 207, 207, 207, 207, 207, 207, 207, 207)
torch.Size([207, 16, 1782])
Loss: %f -4337.747376203537
torch.Size([64, 14, 190])
torch.Size([14, 190])
torch.Size([16, 1, 190, 896])
torch.Size([1, 190, 1782])
torch.Size([190, 1782])
torch.Size([1782])
[96.1875 96.1875 96.1875 96.1875 

Loss: %f -4424.999908685684
torch.Size([64, 14, 222])
torch.Size([14, 222])
torch.Size([16, 1, 222, 896])
torch.Size([1, 222, 1782])
torch.Size([222, 1782])
torch.Size([1782])
[112. 112. 112. 112. 112. 112. 112. 112. 112. 112. 112. 112. 112. 112.
 112. 112.]
222
(222, 222, 222, 222, 222, 222, 222, 222, 222, 222, 222, 222, 222, 222, 222, 222)
torch.Size([222, 16, 1782])
Loss: %f -4429.499386072159
torch.Size([64, 14, 214])
torch.Size([14, 214])
torch.Size([16, 1, 214, 896])
torch.Size([1, 214, 1782])
torch.Size([214, 1782])
torch.Size([1782])
[108.1875 108.1875 108.1875 108.1875 108.1875 108.1875 108.1875 108.1875
 108.1875 108.1875 108.1875 108.1875 108.1875 108.1875 108.1875 108.1875]
214
(214, 214, 214, 214, 214, 214, 214, 214, 214, 214, 214, 214, 214, 214, 214, 214)
torch.Size([214, 16, 1782])
Loss: %f -4434.089679956436
torch.Size([64, 14, 250])
torch.Size([14, 250])
torch.Size([16, 1, 250, 896])
torch.Size([1, 250, 1782])
torch.Size([250, 1782])
torch.Size([1782])
[126.0625 126.06

KeyboardInterrupt: 

In [137]:
# Try to use GPU
# Train using ImageRNN

# SETUP
model_cnn = cnn_model(BATCH_SIZE)
model_rnn = ImageRNN()
optimizer = optim.Adam(list(model_cnn.parameters()) + list(basic_rnn.parameters()))
len_data = len(primus.training_list) + len(primus.validation_list)

model_cnn.to(device)
model_rnn.to(device)

for epoch in range(N_EPOCHS):
    train_loss = 0.
    train_acc = 0.
    model_cnn.train()
    model_rnn.train()
    
    for i in range(0, len_data, BATCH_SIZE):
        # zero parameter gradients
        optimizer.zero_grad()
        
        # reset hidden states
        model_rnn.hidden = basic_rnn.init_hidden()
        
        # Get inputs
        batch = primus.nextBatch(params)

        data = batch['inputs'] # size (batch, height, width, channels)
        data = data.to(device)
        #print(data)
        #print(data.shape)
        max_input_length = data.shape[2]
        
        # list of indices, values, shape
        seq_len = int(batch['seq_lengths'][0])
        targets = ctc_utils.sparse_tuple_from(batch['targets'])
        #print(tuple(targets[2]))
        #print(type(t[0]) for t in targets)
        #targets = torch.sparse_coo_tensor(targets[0], targets[1], tuple(targets[2]))
        targets_0 = torch.as_tensor((targets[0]))
        #print(targets.shape)
        #targets = torch.reshape(targets, (16, 1))
        padded_targets, lengths = ctc_utils.pad_sequences(batch['targets'], maxlen=max_input_length)
        padded_targets_tensor = torch.tensor(padded_targets)
        
        tensor_data = torch.from_numpy(data)
        #print(tensor_data.shape)
        tensor_data_reshape = torch.permute(tensor_data,(0,3, 1, 2))
        
        # forward, backward, optim
        cnn_output = model_cnn(tensor_data_reshape)
        output_size = 64 * 14 * cnn_output.shape[3]
        #print(cnn_output.shape)
        #print(cnn_output[0])
        print(cnn_output[0].shape)
        print(cnn_output[0][0].shape)
        
        # Change shape for rnn
        #output = torch.reshape(cnn_output, (cnn_output.shape[3], 16, 64 * 14)) # width, batch, features
        
        # SEE IF LOOP OVER OUTPUT
        output = torch.reshape(cnn_output, (16, 1, cnn_output.shape[3], 64*14))
        
        print(output.shape)
        logits = model_rnn(output)
        print(logits[0].shape) #1 by width by vocab size
        print((logits[0][0].shape))
        print(logits[0][0][0].shape) # this is logits for one column of picture
        print(batch['seq_lengths'])
        
        for t in range(0,len(logits)):
            logits[t] = torch.reshape(logits[t],(logits[t].size(1), logits[t].size(2)))
        print(logits[0].shape[0])
        
        batch_size = logits[0].shape[0]
        logits_tensor = torch.stack(logits)
        logits_tensor = torch.permute(logits_tensor, (1,0,2))
        
        #rnn_output_reshape = torch.reshape(rnn_output, (cnn_output[0].shape[2], BATCH_SIZE, N_OUTPUTS))
        #rnn_output_reshape = torch.reshape(rnn_output[0], (1, BATCH_SIZE, N_OUTPUTS))
        #rnn_output_reshape = rnn_output[0].view(-1, BATCH_SIZE, N_OUTPUTS)
        
        
        #log_probs = nn.functional.log_softmax(rnn_output, dim=0)
        #Input and target shape
        #print(rnn_output_reshape.shape)
        input_shape = (BATCH_SIZE, params['img_height'], tensor_data_reshape.shape[3],1)
        input_len = tuple([batch_size for i in range (0, 16)])
        #print(input_shape)
        print(input_len)
        print(logits_tensor.shape)
        target_shape = tuple(int(b) for b in batch['seq_lengths'])
        
        # MUST BE TENSOR, TENSOR, TUPLE, TUPLE OR TENSOR TENSOR TENSOR TENSOR
        #loss = criterion(rnn_output_reshape, padded_targets_tensor, input_len, target_shape)
        #loss = criterion(log_probs, padded_targets_tensor, input_len, target_shape)
        loss = criterion(logits_tensor, padded_targets_tensor, input_len, tuple(lengths))
        loss.backward()
        optimizer.step()
        
        train_loss += loss.detach().item()
        train_acc += 0
        print("Loss: %f", train_loss)
        
    #model.eval()
    print('training loss:')
    print(train_loss)
        

NVIDIA GeForce RTX 3070 with CUDA capability sm_86 is not compatible with the current PyTorch installation.
The current PyTorch install supports CUDA capabilities sm_37 sm_50 sm_60 sm_70.
If you want to use the NVIDIA GeForce RTX 3070 GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/



RuntimeError: CUDA error: no kernel image is available for execution on the device
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

In [96]:
for x in output:
    print(x.size(1))

896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896
896


In [81]:
# DEBUGGING
# Train

for epoch in range(N_EPOCHS):
    train_loss = 0.
    train_acc = 0.
    #model.train()
    
    for i in range(0, len_data, BATCH_SIZE):
        # zero parameter gradients
        optimizer.zero_grad()
        
        # reset hidden states
        basic_rnn.hidden = basic_rnn.init_hidden()
        
        # Get inputs
        batch = primus.nextBatch(params)

        data = batch['inputs'] # size (batch, height, width, channels)
        #print(data)
        #print(data.shape)
        max_input_length = data.shape[2]
        
        # list of indices, values, shape
        seq_len = int(batch['seq_lengths'][0])
        targets = ctc_utils.sparse_tuple_from(batch['targets'])
        #print(tuple(targets[2]))
        #print(type(t[0]) for t in targets)
        #targets = torch.sparse_coo_tensor(targets[0], targets[1], tuple(targets[2]))
        targets_0 = torch.as_tensor((targets[0]))
        #print(targets.shape)
        #targets = torch.reshape(targets, (16, 1))
        padded_targets, lengths = ctc_utils.pad_sequences(batch['targets'], maxlen=max_input_length)
        padded_targets_tensor = torch.tensor(padded_targets)
        
        tensor_data = torch.from_numpy(data)
        #print(tensor_data.shape)
        tensor_data_reshape = torch.permute(tensor_data,(0,3, 1, 2))
        
        # forward, backward, optim
        cnn_output = model_cnn(tensor_data_reshape)
        
        # Change shape for rnn
        output = cnn_output.view(cnn_output.size(0), cnn_output.size(1), -1)
        print(output.shape)
        output.permute(2,0,1)
        rnn_output = basic_rnn(output)
        print(rnn_output.shape)
        #print(batch['seq_lengths'])
        
        #rnn_output_reshape = torch.reshape(rnn_output, (cnn_output[0].shape[2], BATCH_SIZE, N_OUTPUTS))
        rnn_output_reshape = torch.reshape(rnn_output, (1, BATCH_SIZE, N_OUTPUTS))
        
        log_probs = nn.functional.log_softmax(rnn_output_reshape)
        #Input and target shape
        #print(rnn_output_reshape.shape)
        input_shape = (BATCH_SIZE, params['img_height'], tensor_data_reshape.shape[3],1)
        input_len = tuple([1 for i in range (0, BATCH_SIZE)])
        #print(input_shape)
        target_shape = tuple(int(b) for b in batch['seq_lengths'])
        
        # MUST BE TENSOR, TENSOR, TUPLE, TUPLE OR TENSOR TENSOR TENSOR TENSOR
        #loss = criterion(rnn_output_reshape, padded_targets_tensor, input_len, target_shape)
        loss = criterion(log_probs, padded_targets_tensor, input_len, target_shape)
        
        loss.backward()
        optimizer.step()
        
        train_loss += loss.detach().item()
        train_acc += 0
        
    #model.eval()
    print('training loss:')
    print(train_loss)
        

torch.Size([16, 64, 2884])


AttributeError: 'int' object has no attribute 'size'

In [22]:
torch.load('./Models/Semantic-Model.zip')

RuntimeError: [enforce fail at inline_container.cc:115] . file in archive is not in a subdirectory: semantic_model.data-00000-of-00001