In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import time

# Run on CUDA when it is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# To force CPU execution instead: device = torch.device("cpu")

Structure: 

![STRUCTURE](/home/ocr/teluguOCR/Structure.webp) 

Hyperparameters of the whole architecture

In [None]:
# Overview of the model's dimensions
Number_of_images = 1000
# (height, width). NOTE(review): EncoderCNN's comments describe a 30x500 input -- confirm which is current.
Image_size = (32, 32)
Image_embedding_size = 5000  # flattened EncoderCNN output: 50 x 1 x 100 for a 30x500 input
Text_embedding_size = 358
Max_Number_of_Words = 350

# Joiner embedding (LSTM_Net.embedding1): image features -> text-embedding width (5000 -> 358)
Joiner_Input_size, Joiner_output_size = Image_embedding_size, Text_embedding_size

# LSTM stack: input, hidden and output widths all equal the joiner output (358)
LSTM_num_layers = 1
LSTM_Input_size = LSTM_hidden_size = LSTM_output_size = Joiner_output_size

# Reverse embedding (LSTM_Net.Dense1): LSTM output -> text-embedding width (358 -> 358)
Reverse_Input_size, Reverse_output_size = LSTM_output_size, Text_embedding_size

drop_prob = 0.2

In [None]:
# Character tables for the Telugu script, used to build the one-hot label encoding.
# NOTE: the multi-character entries ('అం', 'అః', 'క్ష') can never equal a single input
# character, so wordsDicts never matches them; they only reserve one-hot slots.
acchulu = ['అ', 'ఆ', 'ఇ', 'ఈ', 'ఉ', 'ఊ', 'ఋ', 'ౠ', 'ఌ', 'ౡ', 'ఎ', 'ఏ', 'ఐ', 'ఒ', 'ఓ', 'ఔ', 'అం', 'అః']  # independent vowels
hallulu = ['క', 'ఖ', 'గ', 'ఘ', 'ఙ',
           'చ', 'ఛ', 'జ', 'ఝ', 'ఞ',
           'ట', 'ఠ', 'డ', 'ఢ', 'ణ',
           'త', 'థ', 'ద', 'ధ', 'న',
           'ప', 'ఫ', 'బ', 'భ', 'మ',
           'య', 'ర', 'ల', 'వ', 'శ', 'ష', 'స', 'హ', 'ళ', 'క్ష', 'ఱ', 'ఴ', 'ౘ', 'ౙ','ౚ']  # consonants
vallulu = ['ా', 'ి', 'ీ', 'ు' , 'ూ', 'ృ', 'ౄ', 'ె', 'ే', 'ై', 'ొ', 'ో', 'ౌ', 'ం', 'ః', 'ఁ', 'ఀ', 'ఄ', 'ౕ', 'ౖ', 'ౢ' ]  # dependent vowel signs / combining marks
connector = ['్']  # virama: joins the following consonant onto the previous term
numbers = ['౦', '౧', '౨', '౩', '౪', '౫', '౬', '౭', '౮', '౯']
splcharacters= [' ', '!', '"', '#', '$', '%', '&', "'", '(', ')',
              '*', '+', ',', '-', '.', '/', ':', ';', '<', '=', '>', '?', '@', '[',
              '\\', ']', '^', '_', '`', '{', '|', '}', '~', '1','2', '3', '4', '5', '6', '7', '8', '9', '0', 'ఽ']
spl = splcharacters + numbers

bases = acchulu + hallulu + spl  # characters that start a new term (112)
vms = vallulu                    # vowel modifiers (21)
cms = hallulu                    # consonant modifiers, written after the connector (40)

characters = bases+vms+cms+connector

# 1-based index of each character within its table. Comprehensions keep the loop
# variables out of the notebook's global namespace (the old loops left `i` and `x` behind).
base_mapping = {ch: idx for idx, ch in enumerate(bases, start=1)}
vm_mapping = {ch: idx for idx, ch in enumerate(vms, start=1)}
cm_mapping = {ch: idx for idx, ch in enumerate(cms, start=1)}

# A duplicate entry would make two characters share one one-hot slot.
assert len(base_mapping) == len(bases), "duplicate character in bases"
assert len(vm_mapping) == len(vms), "duplicate character in vms"
assert len(cm_mapping) == len(cms), "duplicate character in cms"

# Split a string into a list of dictionaries, one per written term (base + modifiers).
def wordsDicts(s):
  """Group the characters of `s` into terms.

  Each term is a dict with:
    'base': the base character -- a vowel (acchulu), a consonant (hallulu) or a
            special character / number (spl);
    'cm':   optional list of consonant modifiers, i.e. consonants that directly
            follow the connector (connector[0]);
    'vm':   optional list of vowel modifiers (vallulu) attached to the term.
  Characters found in none of the tables (including the connector itself) are skipped.
  A modifier that appears before any base character has no term to attach to and is
  ignored rather than raising IndexError.

  Returns:
    list of term dicts, in input order.
  """
  terms = []
  for i, x in enumerate(s):
    prev = s[i-1] if i > 0 else ''
    after_connector = prev == connector[0]
    if (x in acchulu or x in hallulu) and not after_connector:
      # a vowel, or a consonant not joined to the previous term: start a new term
      terms.append({'base': x})
    elif x in hallulu and after_connector:
      # consonant joined by the connector: attach it to the current term
      if terms:
        terms[-1].setdefault('cm', []).append(x)
    elif x in vallulu:
      # vowel sign: attach it to the current term
      if terms:
        terms[-1].setdefault('vm', []).append(x)
    elif x in spl:
      # special characters and digits are terms of their own
      terms.append({'base': x})
  return terms

def one_hot_encoder(s):
  """Encode string `s` as a (num_terms + 2, width) float32 tensor on `device`.

  Row 0 is the start token (column 0 set) and the last row is the end token
  (column 1 set). Each row in between encodes one term from wordsDicts(s):
    [start, end | base one-hot | vowel-modifier multi-hot | consonant-modifier multi-hot]
  With this notebook's tables that is 2 + 112 + 21 + 40 = 175 columns. The width is
  derived from the tables, so the start/end rows always match the term rows.
  """
  n_base = len(acchulu) + len(hallulu) + len(spl)
  n_vm = len(vallulu)
  n_cm = len(hallulu)
  width = 2 + n_base + n_vm + n_cm

  start = [0] * width
  start[0] = 1
  end = [0] * width
  end[1] = 1

  onehot = [start]
  for D in wordsDicts(s):
    row = [0] * width
    # The mappings are 1-based, hence "- 1"; the leading 2 skips the start/end columns.
    row[2 + base_mapping[D['base']] - 1] = 1
    for j in D.get('vm', []):
      row[2 + n_base + vm_mapping[j] - 1] = 1
    for j in D.get('cm', []):
      row[2 + n_base + n_vm + cm_mapping[j] - 1] = 1
    onehot.append(row)
  onehot.append(end)

  return torch.tensor(onehot).float().to(device)

def One_Hot_Decoder(List, threshold=0.3):
  """Decode a sequence of (soft) one-hot rows back into a string.

  Each row follows the one_hot_encoder layout
    [start, end | base | vowel modifiers | consonant modifiers].
  A column counts as set when its value is >= `threshold` (default 0.3). For each row
  the output is: the set base characters, then connector + consonant for each set
  consonant modifier, then the set vowel modifiers. The start/end columns are ignored.
  Input rows are not modified; thresholding is done on a boolean copy.

  Args:
    List: iterable of 1-D numpy arrays or torch tensors (e.g. a 2-D array).
    threshold: activation cut-off for a column to count as present.
  Returns:
    the decoded string.
  """
  n_base, n_vm, n_cm = len(bases), len(vms), len(cms)
  vm_off = 2 + n_base       # first vowel-modifier column
  cm_off = vm_off + n_vm    # first consonant-modifier column
  x = ""
  for onehoti in List:
    on = onehoti >= threshold
    for i in range(n_base):
      if on[2 + i]:
        x += bases[i]
    for i in range(n_cm):
      if on[cm_off + i]:
        x += connector[0] + cms[i]
    for i in range(n_vm):
      if on[vm_off + i]:
        x += vms[i]
  return x


In [None]:
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class EncoderCNN(nn.Module):
    """Convolutional image encoder.

    Five unpadded Conv2d layers, each followed by ReLU. Kernels are 3 pixels tall and
    51-101 pixels wide. conv1, conv4 and conv5 use stride 2 along the height.
    Shape trace for a (1, 30, 500) input (channels, H, W):
        conv1 -> (10, 14, 400), conv2 -> (20, 12, 350), conv3 -> (30, 10, 300),
        conv4 -> (40, 4, 200),  conv5 -> (50, 1, 100)
    The final map is flattened to (batch, 1, 5000).
    """

    def __init__(self) -> None:
        super().__init__()
        # Attribute names conv1..conv5 are part of the saved state_dict keys.
        self.conv1 = nn.Conv2d(1, 10, kernel_size=(3, 101), stride=(2, 1), padding=(0, 0))
        self.conv2 = nn.Conv2d(10, 20, kernel_size=(3, 51), stride=(1, 1), padding=(0, 0))
        self.conv3 = nn.Conv2d(20, 30, kernel_size=(3, 51), stride=(1, 1), padding=(0, 0))
        self.conv4 = nn.Conv2d(30, 40, kernel_size=(3, 101), stride=(2, 1), padding=(0, 0))
        self.conv5 = nn.Conv2d(40, 50, kernel_size=(3, 101), stride=(2, 1), padding=(0, 0))

    def forward(self, x):
        """Encode a (batch, 1, H, W) image batch into a (batch, 1, C*H'*W') tensor."""
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            x = F.relu(conv(x))
        # Flatten channels x height x width per sample. The middle axis of length 1 is a
        # one-step sequence for the batch-first LSTM decoder.
        batch = x.size(0)
        return x.view(batch, 1, -1)

    def configure_optimizers(self):
        """Return Adam over this module's parameters with lr=0.01.

        The training cell builds its own Adam optimizer (lr=1e-5) instead of calling this.
        """
        return optim.Adam(self.parameters(), lr=0.01)

# encoder = EncoderCNN().to(device)
# image = torch.randn(20, 1, 30, 500).to(device)
# output = encoder(image)
# print(output.shape)

In [None]:
# class EncoderCNN(nn.Module):
#     def __init__(self):
#         super(EncoderCNN, self).__init__()
        
#         # Convolutional layers
#         # input size: (batch_size, 1, 300, 300)
#         self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1)
#         self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
#         self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
#         self.conv4 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
#         self.conv5 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        
#         self.relu = nn.LeakyReLU(negative_slope = 0.2)
#         # Pooling layers
#         self.pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)

#         # Fully connected layers
#         self.fc1 = nn.Linear(512 * 9 * 9, 4096)
#         self.fc2 = nn.Linear(4096, 2048)
#         self.fc3 = nn.Linear(2048, 1024)
#         self.fc4 = nn.Linear(1024, Image_embedding_size)

#     def forward(self, x):
#         # Input size: (batch_size, 1, 300, 300)

#         # Convolutional layers with ReLU activation and pooling
#         x = self.pool(self.relu(self.conv1(x))) # (batch_size, 64, 150, 150)
#         x = self.pool(self.relu(self.conv2(x))) # (batch_size, 128, 75, 75)
#         x = self.pool(self.relu(self.conv3(x))) # (batch_size, 256, 37, 37)
#         x = self.pool(self.relu(self.conv4(x))) # (batch_size, 512, 18, 18)
#         x = self.pool(self.relu(self.conv5(x))) # (batch_size, 512, 9, 9)

#         # Flatten the output before fully connected layers
#         x = x.view(-1, 512 * 9 * 9)

#         # Fully connected layers
#         x = F.relu(self.fc1(x))
#         x = F.relu(self.fc2(x))
#         x = F.relu(self.fc3(x))
#         x = self.fc4(x)

#         return x

#     def configure_optimizers(self):
#         optimizer = optim.Adam(self.parameters(), lr=0.01)
#         return optimizer

In [None]:
# import torch.nn as nn
# import torch.optim as optim
# import torch.nn.functional as F

# class EncoderCNN(nn.Module):
#     def __init__(self) -> None:
#         super(EncoderCNN, self).__init__()
#         # input: 32x32
#         self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=0)  # 30x30x64
#         self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)  # 30x30x64
#         self.BatchNorm1 = nn.BatchNorm2d(64, momentum=0.1)
#         # self.MaxPool1 = nn.MaxPool2d(kernel_size=2, stride=2)  # 15x15x64
#         self.dropout1 = nn.Dropout2d(p=drop_prob)

#         self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)  # 15x15x128
#         self.conv4 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)  # 15x15x128
#         self.BatchNorm2 = nn.BatchNorm2d(128, momentum=0.1)
#         self.MaxPool2 = nn.MaxPool2d(kernel_size=2, stride=2)  # 7x7x128
#         self.dropout2 = nn.Dropout2d(p=drop_prob)

#         self.conv5 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)  # 7x7x256
#         self.conv6 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)  # 7x7x256
#         self.BatchNorm3 = nn.BatchNorm2d(256, momentum=0.1)
#         self.MaxPool3 = nn.MaxPool2d(kernel_size=2, stride=2) # 3x3x256
#         self.dropout3 = nn.Dropout2d(p=drop_prob)

#         self.conv7 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)  # 3x3x512
#         self.conv8 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)  # 3x3x512
#         self.BatchNorm4 = nn.BatchNorm2d(512, momentum=0.1)
#         self.MaxPool4 = nn.MaxPool2d(kernel_size=3, stride=3) # 1x1x512
#         self.dropout4 = nn.Dropout2d(p=drop_prob)

#         # Assuming Image_embedding_size is the output size of the linear layer
#         self.Dense = nn.Sequential(
#             nn.Linear(512, 1000),
#             nn.ReLU(),
#             nn.Linear(1000, 500),
#             nn.ReLU(),
#             nn.Linear(500, Image_embedding_size)
#             )

#     def forward(self, x):
#         # input: 32x32
#         x1 = F.relu(self.conv1(x)) # 30x30x64
#         x2 = self.conv2(x1) # 30x30x64
#         x3 = F.relu(torch.add(x1, self.BatchNorm1(x2))) # skip connection of x1 and x2 (residual connection)
#         x = self.MaxPool1(x3) # 15x15x64
#         x = self.dropout1(x)

#         x1 = F.relu(self.conv3(x)) # 15x15x128
#         x2 = self.conv4(x1) # 30x30x64
#         x3 = F.relu(torch.add(x1, self.BatchNorm2(x2))) # skip connection of x1 and x2 (residual connection)
#         x = self.MaxPool2(x3) # 7x7x128
#         x = self.dropout2(x)

#         x1 = F.relu(self.conv5(x)) # 7x7x256
#         x2 = self.conv6(x1) # 30x30x64
#         x3 = F.relu(torch.add(x1, self.BatchNorm3(x2))) # skip connection of x1 and x2 (residual connection)
#         x = self.MaxPool3(x3) # 3x3x256
#         x = self.dropout3(x)

#         x1 = F.relu(self.conv7(x)) # 3x3x512
#         x2 = self.conv8(x1) # 30x30x64
#         x3 = F.relu(torch.add(x1, self.BatchNorm4(x2))) # skip connection of x1 and x2 (residual connection)
#         x = self.MaxPool4(x3) # 1x1x512
#         x = self.dropout4(x)

#         # Reshape before passing through linear layers
#         x = x.view(x.size(0), -1)
#         x = self.Dense(x)
#         return x

#     def configure_optimizers(self):
#         optimizer = optim.Adam(self.parameters(), lr=0.01)
#         return optimizer


In [None]:
class LSTM_Net(nn.Module):
    """Step-wise sequence decoder used after EncoderCNN.

    Each call runs: [optional joiner embedding] -> three chained bidirectional LSTMs
    -> dot-product self-attention -> ReLU -> bias-free linear layer (Dense1).

    The (h, c) states of the three LSTMs are stored on the module (hidden1..hidden3)
    and persist between calls, so the network can be driven one step at a time:
      * forward(x, New=True) resets the states to zeros for x's batch size. It then maps
        x from Joiner_Input_size to Joiner_output_size with embedding1. It is called
        with the CNN image features.
      * forward(x, New=False) continues from the stored states. x must already be
        LSTM_Input_size wide.
    All LSTMs are batch_first, so tensors are (batch, seq, features). The output has
    Reverse_output_size features.
    """

    def __init__(self) -> None:
        super(LSTM_Net, self).__init__()
        # joiner embedding sizes: Joiner_Input_size -> Joiner_output_size (5000 -> 358 here)
        self.einput_size = Joiner_Input_size
        self.eoutput_size = Joiner_output_size
        # LSTM parameters (hidden_size is stored; the LSTMs below use embed_size/2 per direction)
        self.embed_size = LSTM_Input_size
        self.hidden_size = LSTM_hidden_size
        self.num_layers = LSTM_num_layers
        # reverse (output) layer sizes: Reverse_Input_size -> Reverse_output_size (358 -> 358 here)
        self.Rinput_size = Reverse_Input_size
        self.Routput_size = Reverse_output_size

        # dense joiner embedding applied to the image features on New=True calls
        self.embedding1 = nn.Linear(self.einput_size, self.eoutput_size, bias=False)

        # Three chained bidirectional LSTMs with int(embed_size/2) units per direction.
        # Each outputs 2 * int(embed_size/2) features, which equals embed_size only when
        # embed_size is even; that is what lets lstm2/lstm3 consume the previous output.
        # nn.LSTM applies dropout only *between* stacked layers. With a single layer the
        # dropout does nothing and PyTorch warns, so it is passed only for multi-layer LSTMs.
        lstm_half = int(self.embed_size/2)
        lstm_dropout = drop_prob if self.num_layers > 1 else 0.0
        self.lstm1 = nn.LSTM(input_size=self.embed_size, hidden_size=lstm_half, num_layers=self.num_layers, bidirectional=True, batch_first=True, dropout=lstm_dropout)
        self.lstm2 = nn.LSTM(input_size=self.embed_size, hidden_size=lstm_half, num_layers=self.num_layers, bidirectional=True, batch_first=True, dropout=lstm_dropout)
        self.lstm3 = nn.LSTM(input_size=self.embed_size, hidden_size=lstm_half, num_layers=self.num_layers, bidirectional=True, batch_first=True, dropout=lstm_dropout)

        # query / key / value projections for self-attention over the LSTM output sequence
        self.attention_Q = nn.Linear(self.Rinput_size, self.Rinput_size)
        self.attention_K = nn.Linear(self.Rinput_size, self.Rinput_size)
        self.attention_V = nn.Linear(self.Rinput_size, self.Rinput_size)

        # final bias-free projection to the output width
        self.Dense1 = nn.Linear(self.Rinput_size, self.Routput_size, bias=False)

        # Not used by forward (which calls F.relu); kept so the module's attributes are unchanged.
        self.relu = nn.ReLU()

    def init_hidden(self, batch_size):
        """Reset the (h, c) states of lstm1..lstm3 to zeros for `batch_size` sequences.

        Each state tensor is (2 * num_layers, batch_size, int(embed_size/2)): the factor 2
        covers the two directions of a bidirectional LSTM.
        """
        def zero_state():
            return (torch.zeros(2*self.num_layers, batch_size, int(self.embed_size/2)).to(device),
                    torch.zeros(2*self.num_layers, batch_size, int(self.embed_size/2)).to(device))

        self.hidden1 = zero_state()
        self.hidden2 = zero_state()
        self.hidden3 = zero_state()

    def forward(self, input, New = False):
        """Run the decoder on `input` and return a (batch, seq, Routput_size) tensor.

        Args:
            input: batch-first tensor. It is Joiner_Input_size wide when New is True and
                LSTM_Input_size wide otherwise.
            New: if True, reset the LSTM states and apply the joiner embedding first.
        """
        if New:
            # New sequence (image features): zero the states, then embed to LSTM width.
            self.init_hidden(input.shape[0])
            input = self.embedding1(input)
        elif not hasattr(self, 'hidden1'):
            # Continuing call before any New=True call: start from zero states instead of
            # failing with AttributeError on self.hidden1.
            self.init_hidden(input.shape[0])

        # LSTM layers; the updated states are stored for the next call
        output1, self.hidden1 = self.lstm1(input, self.hidden1)
        output2, self.hidden2 = self.lstm2(output1, self.hidden2)
        output3, self.hidden3 = self.lstm3(output2, self.hidden3)

        # Unscaled dot-product self-attention along the sequence axis. For a length-1
        # sequence (as in the step-wise training loop) the softmax runs over a single
        # key, so the result is exactly V.
        Q = self.attention_Q(output3)
        K = self.attention_K(output3)
        V = self.attention_V(output3)
        attention = torch.bmm(Q, K.transpose(1, 2))
        attention = F.softmax(attention, dim=2)
        attention = torch.bmm(attention, V)

        # output layer
        attention = F.relu(attention)
        attention = self.Dense1(attention)

        return attention

Training

In [None]:
# Training history. The training loop appends each epoch's summed loss to Losses.
saved_model_losses_min, saved_model_losses_max, Losses = [], [], []

In [None]:
cnn = EncoderCNN().to(device)
network = LSTM_Net().to(device)

# To resume training, load the latest checkpoints (map_location lets GPU-saved weights load on CPU):
# cnn.load_state_dict(torch.load('/home/ocr/teluguOCR/Saved_Models/CNN_latest.pth', map_location=device))
# network.load_state_dict(torch.load('/home/ocr/teluguOCR/Saved_Models/Network_latest.pth', map_location=device))

cnn.train()
network.train()

# A single optimizer over both models, so the CNN and the decoder are trained end to end.
params = list(network.parameters()) + list(cnn.parameters())
optimizer = optim.Adam(params, lr=1e-5)

# Max global L2 norm for gradient clipping. Clipping only takes effect when
# torch.nn.utils.clip_grad_norm_(params, clip) runs after loss.backward() and before
# optimizer.step() in the training loop. Calling it here, before any gradients
# exist, does nothing.
clip = 1.0

In [None]:
# Mean-squared error between predicted and target label rows. nn.MSELoss has no
# parameters or buffers, so it does not need to be moved to the GPU.
critereon = nn.MSELoss()

num_of_epochs = 5000

Images_path = "/home/ocr/teluguOCR/Dataset/Batch_Image_Tensors/Image"
Labels_path = '/home/ocr/teluguOCR/Dataset/Batch_Label_Tensors/Label'

def get_data_loader(i):
    """Load batch file `i` onto `device` and return (images, labels), with labels as float32.

    torch.load unpickles the file, so only use it on trusted data.
    """
    images = torch.load(Images_path + str(i) + '.pt', map_location=device)
    labels = torch.load(Labels_path + str(i) + '.pt', map_location=device)
    labels = labels.float()
    return images, labels

Num_of_files = 50
batchSize = 500


def _train_on_batch(images_, labels_):
    """Run one teacher-forced optimization step on a batch and return its loss as a float.

    Step 0 is predicted from the CNN features (New=True resets the LSTM states). Step
    t+1 is predicted from the ground-truth label row at step t.
    """
    optimizer.zero_grad()
    features = cnn(images_)
    # [:, 0, :] keeps every sample's prediction (shape (B, D)). The old [0][0] kept only
    # batch element 0 and broadcast it to the whole batch.
    step_outputs = [network(features, New = True)[:, 0, :]]
    for t in range(labels_.shape[1] - 1):
        step_outputs.append(network(labels_[:, t, :].unsqueeze(1), New = False)[:, 0, :])
    # (B, T, D). Assumes the network output width matches labels_' last dim -- as the old
    # zeros_like(labels_) assignment also required.
    outputs = torch.stack(step_outputs, dim=1)

    loss = critereon(outputs, labels_)
    loss.backward()
    # Clip after backward and before step; this is the only place where clipping has an effect.
    torch.nn.utils.clip_grad_norm_(params, clip)
    optimizer.step()
    return loss.item()


for epoch in range(1, num_of_epochs + 1):
    start = time.time()
    l_min = 1e18
    l_max = 0
    l = 0

    # if epoch % 100 == 0:
    #     torch.save(network.state_dict(), '/home/ocr/teluguOCR/Saved_Models/Network_latest.pth')
    #     torch.save(cnn.state_dict(), '/home/ocr/teluguOCR/Saved_Models/CNN_latest.pth')

    # if epoch == 50:
    #     optimizer = optim.Adam(params, lr=1e-6)

    num_of_points = 0
    for j in range(1, Num_of_files + 1):
        images, labels = get_data_loader(j)
        size = images.shape[0]
        num_of_points += size
        # fl = summed batch losses for this file; a file of <= batchSize samples is one batch.
        fl = 0
        for k in range(0, size, batchSize):
            fl += _train_on_batch(images[k:k + batchSize], labels[k:k + batchSize])
        del images, labels
        l_min = min(l_min, fl)
        l_max = max(l_max, fl)
        l += fl
    print(f"Epoch {epoch} completed in {format(time.time() - start, '.0f')} seconds with loss ({l_min}, {l_max}), {l}")
    Losses.append(l)

In [None]:
import matplotlib.pyplot as plt

# Training curve. The first two epochs are skipped (presumably so their larger losses do
# not dominate the y-scale). The x values are the actual epoch numbers of the plotted points.
fig, ax = plt.subplots()
ax.plot(range(3, len(Losses) + 1), Losses[2:], label="Training Loss", color='red')
ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")
ax.set_title("Loss vs Epochs with MSELoss")  # the training criterion is nn.MSELoss
ax.legend()
plt.show()

if Losses:
    print(Losses[0], Losses[-1])  # first vs. latest epoch loss


In [None]:
# torch.save(network.state_dict(), '/home/ocr/teluguOCR/Saved_Models/Network_1000.pth')
# torch.save(cnn.state_dict(), '/home/ocr/teluguOCR/Saved_Models/CNN_1000.pth')

In [None]:
# torch.save(torch.tensor(Losses), '/home/ocr/teluguOCR/Losses/Losses_from_4000_to_5500.pt')

In [None]:
# Deliberate stop so that "Run All" does not continue into the evaluation cells below,
# which load saved checkpoints from disk. This presumably replaces the old placeholder
# line, which stopped execution by raising NameError.
raise RuntimeError("Stopping here: the cells below evaluate saved checkpoints; run them manually.")

Testing

In [None]:
cnn1 = EncoderCNN().to(device)
network1 = LSTM_Net().to(device)

# map_location lets checkpoints saved on a GPU be evaluated on a CPU-only machine.
cnn1.load_state_dict(torch.load('/home/ocr/teluguOCR/Saved_Models/CNN_600.pth', map_location=device))
network1.load_state_dict(torch.load('/home/ocr/teluguOCR/Saved_Models/Network_600.pth', map_location=device))

cnn1.eval()
network1.eval()
mse = nn.MSELoss()
eval_batch_size = 1000


def _generate(images_, num_steps):
    """Autoregressively predict `num_steps` label rows for a batch of images.

    Step 0 is predicted from the CNN features; New=True resets the LSTM states and
    applies the joiner embedding. Every later step is fed the previous prediction.
    CNN features are already (batch, 1, N), so no extra unsqueeze is needed (a 4-D input
    is rejected by nn.LSTM). Returns a (batch, num_steps, D) tensor.
    """
    step = network1(cnn1(images_), New=True)
    steps = [step]
    for _ in range(num_steps - 1):
        step = network1(step, New=False)
        steps.append(step)
    return torch.cat(steps, dim=1)


with torch.no_grad():  # evaluation only: no autograd graph, much lower memory use
    for ind in range(1, 60):
        Images = torch.load('/home/ocr/teluguOCR/Dataset/Batch_Image_Tensors/Image' + str(ind) + '.pt', map_location=device)
        # .float() matches the label dtype used during training (get_data_loader).
        Labels = torch.load('/home/ocr/teluguOCR/Dataset/Batch_Label_Tensors/Label' + str(ind) + '.pt', map_location=device).float()
        print(ind)
        loss = 0.0  # sum of per-chunk MSE for this file
        for k in range(0, Labels.shape[0], eval_batch_size):
            Label = Labels[k:k + eval_batch_size]
            outputs = _generate(Images[k:k + eval_batch_size], Label.shape[1])
            loss += mse(outputs, Label).item()
        del Images, Labels
        print(loss)

In [None]:
import matplotlib.pyplot as plt

In [None]:
cnn = EncoderCNN().to(device)
network = LSTM_Net().to(device)

# map_location lets GPU-saved checkpoints load on a CPU-only machine.
cnn.load_state_dict(torch.load('/home/ocr/teluguOCR/Saved_Models/CNN_latest.pth', map_location=device))
network.load_state_dict(torch.load('/home/ocr/teluguOCR/Saved_Models/Network_latest.pth', map_location=device))

cnn.eval()
network.eval()

# Take the first sample of batch file 81. Indexing happens on the CPU before the move to
# `device`, so the whole batch is not copied to the GPU just to keep one sample.
Image = torch.load('/home/ocr/teluguOCR/Dataset/Batch_Image_Tensors/Image' + str(81) + '.pt', map_location='cpu')[0].to(device)
Label = torch.load('/home/ocr/teluguOCR/Dataset/Batch_Label_Tensors/Label' + str(81) + '.pt', map_location='cpu')[0].to(device)

Image = Image.unsqueeze(0)  # add a batch axis

plt.imshow(Image[0][0].cpu().detach().numpy(), cmap = 'gray')
plt.show()

print(Label.shape)
# Copy before decoding. On CPU, .numpy() shares memory with Label, and One_Hot_Decoder
# may threshold its input rows in place.
print("actual: ", One_Hot_Decoder(Label.cpu().detach().numpy().copy()))

features = cnn(Image)

def roundoff(output):
    """Turn one step of raw scores into a hard 0/1 vector of the same shape (on the CPU).

    Only output[0][0] is read, in the 175-column label layout:
      * columns 0-113 (start/end + base): softmax, and the single argmax is set to 1;
      * columns 114-134 (vowel modifiers): sigmoid, and the top 3 are rounded to 0/1;
      * columns 135+ (consonant modifiers): sigmoid, and the top 4 are rounded to 0/1.
    Every other entry is 0.
    """
    rounded = torch.zeros(output.shape)
    scores = output[0][0]
    base_probs = torch.softmax(scores[:114], dim = 0)
    vm_probs = torch.sigmoid(scores[114:135])
    cm_probs = torch.sigmoid(scores[135:])

    target = rounded[0][0]  # view into `rounded`
    target[torch.argmax(base_probs)] = 1
    # Keep at most `keep` modifiers per group, each only if its probability rounds to 1.
    for offset, probs, keep in ((114, vm_probs, 3), (135, cm_probs, 4)):
        for idx in torch.argsort(probs, descending = True)[:keep]:
            target[idx + offset] = torch.round(probs[idx])

    return rounded



# Greedy decoding of one image:
# - feed the image features once (New=True);
# - then repeatedly feed the network its own previous output, for at most 100 more steps;
# - round each step with roundoff;
# - stop once column 1 (the end-token column in the one_hot_encoder layout) is set.
# NOTE(review): `decoder` is not defined anywhere in this notebook as visible, so this cell
# raises NameError unless it comes from elsewhere -- TODO confirm / replace.
# NOTE(review): EncoderCNN.forward already returns a 3-D (batch, 1, N) tensor. The extra
# unsqueeze(0) below makes it 4-D, which nn.LSTM(batch_first=True) does not accept -- confirm.
# NOTE(review): roundoff / reshape(175) assume the 175-column label layout, while LSTM_Net
# outputs Reverse_output_size (358) columns -- confirm what `decoder` should map between.
F_output = []
features = features.unsqueeze(0)
output = network(features, New = True).to(device)
# print(output.shape)
F_output.append(roundoff(decoder(output)).reshape(175).to(device))
# while True:
for _ in range(100):
    output = network(output, New = False).to(device)
    # print(output.shape)
    F_output.append(roundoff(decoder(output)).reshape(175).to(device))
    if(F_output[-1][1] == 1):
        break


In [None]:
# Stack the decoded steps into a (num_steps, 175) array and show the result.
fout = torch.stack(F_output).cpu().detach().numpy()
print(fout.shape)
print(fout[2][:114])  # start/end + base columns of the third decoded step
s = One_Hot_Decoder(fout)
# Each character followed by " + ", which makes combining marks visible on their own.
print(''.join(c + ' + ' for c in s))
print(s)