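## Training Our Final Model

The pruned encoder–decoder (the student) is trained on the RSICD captions with three loss terms:
- cross-entropy against the ground-truth next token;
- an L2 penalty on the pruned encoder's weights;
- an MSE term that matches the student's decoder hidden states to selected decoder layers of the original captioning model (the teacher).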

In [10]:
import torch
import torch.optim as optim
from get_loader import get_loader
from torchvision import transforms
import New_Pruned_Model
import torch.nn as nn
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from utils import print_examples
from New_Pruned_Model import EncoderCNN
from Model import EncodertoDecoder

# Shared criterion used to compare student and teacher hidden states.
mse_loss = nn.MSELoss()


def match_hidden_states(sub_network_hidden_states, decoder_hidden_states, num_selected_layers=3):
    """Summed MSE between student hidden states and a subset of teacher decoder layers.

    ``num_selected_layers`` indices are chosen from ``decoder_hidden_states``. Each selected
    teacher state is paired positionally (via ``zip``) with the matching entry of
    ``sub_network_hidden_states``.

    Args:
        sub_network_hidden_states: sequence of tensors from the pruned (student) decoder.
        decoder_hidden_states: indexable sequence of tensors from the original (teacher) decoder.
        num_selected_layers: how many teacher layers to select.

    Returns:
        The summed MSE loss (a tensor), or the int ``0`` when no pairs are formed.
    """
    num_total_layers = len(decoder_hidden_states)

    # NOTE(review): the first index is ``num_selected_layers - 1``, which does not depend on
    # the teacher depth. It can duplicate one of the evenly spaced indices below (e.g. [2, 1, 2]
    # for 5 teacher states) or be out of range for a shallow teacher. Confirm whether the last
    # teacher layer (``num_total_layers - 1``) was intended here.
    selected_layers_indices = [num_selected_layers - 1]
    for i in range(num_selected_layers - 1):
        selected_layers_indices.append(int((num_total_layers - 1) * (i + 1) / num_selected_layers))

    selected_decoder_hidden_states = [decoder_hidden_states[idx] for idx in selected_layers_indices]

    loss = 0
    # zip stops at the shorter sequence, so surplus states on either side are ignored.
    for sub_state, dec_state in zip(sub_network_hidden_states, selected_decoder_hidden_states):
        loss += mse_loss(sub_state, dec_state)
    return loss

def train(images_path=r"D:\ML\Korea\Jishu\Jishu\rsicd\images",
          caption_path=r"D:\ML\Korea\Jishu\Jishu\rsicd\captions.csv",
          pruned_resnet_model_path=r"D:\ML\Korea\Jishu\Jishu\Cnn_Pruning\Pruned_Resnet\fine_tuned_model.pth",
          teacher_model_path=r"D:\ML\Korea\Jishu\Jishu\Final_Docs\Original_Image_Captioning_Model\model.pth",
          num_epochs=20,
          save_path=None):
    """Distil the original captioning model (teacher) into the pruned model (student).

    Each batch's loss is cross-entropy on the next caption token, plus the pruned encoder's
    weight penalty, plus a weighted MSE between student and teacher decoder hidden states
    (see ``match_hidden_states``). After every epoch the mean loss is printed and sample
    captions are shown via ``print_examples``.

    Args:
        images_path: directory of RSICD images.
        caption_path: CSV file of captions.
        pruned_resnet_model_path: checkpoint of the pruned ResNet encoder used by the student.
        teacher_model_path: state_dict of the original (teacher) captioning model.
        num_epochs: number of training epochs.
        save_path: if given, the student's state_dict is saved here after every epoch.

    Returns:
        list[float]: mean training loss of each epoch.
    """
    transform = transforms.Compose([
        transforms.Resize((350, 350)),
        transforms.RandomCrop((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    BATCH_SIZE = 32
    data_loader, dataset = get_loader(images_path, caption_path, transform,
                                      batch_size=BATCH_SIZE, num_workers=4)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    learning_rate = 3e-4
    trg_vocab_size = len(dataset.vocab)
    embedding_size = 512
    num_heads = 8
    teacher_num_decoder_layers = 4  # must match the saved teacher checkpoint
    num_decoder_layers = 2          # depth of the pruned (student) decoder
    dropout = 0.10
    pad_idx = dataset.vocab.stoi["<PAD>"]
    encoder_regularization_penalty = 0.01
    decoder_regularization_penalty = 0.01

    # Teacher: frozen original model; only used under no_grad below.
    model = New_Pruned_Model.EncodertoDecoder(
        embeding_size=embedding_size, trg_vocab_size=trg_vocab_size, num_heads=num_heads,
        num_decoder_layers=teacher_num_decoder_layers, dropout=dropout).to(device)
    # The checkpoint is a plain state_dict, so weights_only loading suffices (and avoids
    # unpickling arbitrary objects).
    model.load_state_dict(torch.load(teacher_model_path, map_location=device, weights_only=True))
    model.eval()

    # Student: pruned model that is actually optimised.
    pruned_model = New_Pruned_Model.PrunedEncodertoDecoder(
        embeding_size=embedding_size, trg_vocab_size=trg_vocab_size, num_heads=num_heads,
        num_decoder_layers=num_decoder_layers, dropout=dropout,
        pruned_resnet_model_path=pruned_resnet_model_path).to(device)

    optimizer = optim.Adam(pruned_model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)

    epoch_losses = []
    step = 0
    writer = SummaryWriter("runs/loss_plot")
    try:
        for epoch in range(num_epochs):
            print(f"[Epoch {epoch} / {num_epochs}]")
            pruned_model.train()
            running_loss = 0.0

            for images, captions in tqdm(data_loader, total=len(data_loader), leave=False):
                images = images.to(device)
                captions = captions.to(device)
                # Teacher forcing: the decoder sees all but the last token and predicts all but
                # the first. NOTE(review): this assumes captions are (seq_len, batch); confirm
                # against get_loader's collate function.
                decoder_input = captions[:-1]
                target = captions[1:].reshape(-1)

                with torch.no_grad():
                    # NOTE(review): assumes the teacher's forward returns
                    # (logits, decoder_hidden_states), like the pruned model's forward does.
                    _, teacher_hidden_states = model(images, decoder_input)

                student_logits, student_hidden_states = pruned_model(images, decoder_input)
                student_logits = student_logits.reshape(-1, student_logits.shape[2])

                optimizer.zero_grad()
                ce_loss = criterion(student_logits, target)
                # Weight penalty on the pruned encoder, already scaled by the given factor.
                encoder_penalty = pruned_model.encoder.compute_penalty(encoder_regularization_penalty)
                hidden_loss = match_hidden_states(student_hidden_states, teacher_hidden_states)
                total_loss = ce_loss + encoder_penalty + hidden_loss * decoder_regularization_penalty

                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(pruned_model.parameters(), max_norm=1)
                optimizer.step()

                batch_loss = total_loss.item()
                running_loss += batch_loss
                # Log the per-batch loss, not the epoch's cumulative running sum.
                writer.add_scalar("Training Loss", batch_loss, global_step=step)
                step += 1

            epoch_loss = running_loss / len(data_loader)
            epoch_losses.append(epoch_loss)
            print("Loss of the epoch is", epoch_loss)

            if save_path is not None:
                torch.save(pruned_model.state_dict(), save_path)

            pruned_model.eval()
            print_examples(pruned_model, device, dataset)
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        writer.close()

    return epoch_losses

In [3]:
# Run the full training loop; it prints each epoch's mean loss and sample captions.
train()

  model.load_state_dict(torch.load(r'D:\ML\Korea\Jishu\Jishu\Final_Docs\Original_Image_Captioning_Model\model.pth', map_location=device))
  self.pruned_image_encoder = torch.load(pruned_resnet_model_path)


[Epoch 0 / 20]


  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
                                                   

Loss of the epoch is 31.567309242272415




Example 1 CORRECT:  <SOS> the mountain range is vast and rugged . <EOS> 


Example 1 OUTPUT:  many green trees and some buildings are in two sides of a curved river . <EOS>
Example 1 CORRECT:  <SOS> four tennis courts are close to a large stadium . <EOS> 


Example 1 OUTPUT:  a bridge is over a river with many green trees in two sides of it . <EOS>


[Epoch 1 / 20]


                                                   

Loss of the epoch is 10.97519461612373




Example 1 CORRECT:  <SOS> many green buildings and trees are located in an average residential area . <EOS> 


Example 1 OUTPUT:  many buildings and green trees are in a school . <EOS>
Example 1 CORRECT:  <SOS> the round area is a large baseball field . <EOS> 


Example 1 OUTPUT:  a baseball field is surrounded by some green trees . <EOS>


[Epoch 2 / 20]


                                                   

Loss of the epoch is 3.537011243554535




Example 1 CORRECT:  <SOS> some green trees and several buildings are around a baseball pitch . <EOS> 


Example 1 OUTPUT:  a baseball field is close to several green trees and a road . <EOS>
Example 1 CORRECT:  <SOS> the ocean is vast . <EOS> 


Example 1 OUTPUT:  a piece of ocean is near a yellow beach and some green trees . <EOS>


[Epoch 3 / 20]


                                                   

Loss of the epoch is 1.6197341238388994




Example 1 CORRECT:  <SOS> a bridge with two towers is above a bright yellow river with ships . <EOS> 


Example 1 OUTPUT:  a road is near a large piece of bareland . <EOS>
Example 1 CORRECT:  <SOS> a square is surrounded by many green trees near a road . <EOS> 


Example 1 OUTPUT:  many green trees are around a square square . <EOS>


[Epoch 4 / 20]


                                                   

Loss of the epoch is 1.3430183215973337




Example 1 CORRECT:  <SOS> many green trees and buildings are located on two sides of a railway station . <EOS> 


Example 1 OUTPUT:  many buildings and some green trees are near a viaduct . <EOS>
Example 1 CORRECT:  <SOS> oasis are in the middle of the field , the cars drive on the road . <EOS> 


Example 1 OUTPUT:  many green trees are near a viaduct with some cars . <EOS>


[Epoch 5 / 20]


                                                   

Loss of the epoch is 1.2731657699799874




Example 1 CORRECT:  <SOS> a river is among many pieces of mountain caki . <EOS> 


Example 1 OUTPUT:  it 's a piece of irregular khaki mountain . <EOS>
Example 1 CORRECT:  <SOS> the top of the mountain is naked while other parts are green . <EOS> 


Example 1 OUTPUT:  the mountain is yellow and green . <EOS>


[Epoch 6 / 20]


                                                   

Loss of the epoch is 1.227709257966476




Example 1 CORRECT:  <SOS> several green buildings and trees are located on two sides of a train station . <EOS> 


Example 1 OUTPUT:  many buildings are located on both sides of a railway station . <EOS>
Example 1 CORRECT:  <SOS> a playground with a basketball court next door is surrounded by a few buildings and plants . <EOS> 


Example 1 OUTPUT:  a playground with a football field in it is surrounded by many buildings . <EOS>


[Epoch 7 / 20]


                                                   

Loss of the epoch is 1.187538374664265




Example 1 CORRECT:  <SOS> the huge lake lies in the middle of the woods . <EOS> 


Example 1 OUTPUT:  many green trees are around an irregular pond . <EOS>
Example 1 CORRECT:  <SOS> it is a piece of yellow desert . <EOS> 


Example 1 OUTPUT:  it is a piece of yellow desert . <EOS>


[Epoch 8 / 20]


                                                   

Loss of the epoch is 1.1446820156581712




Example 1 CORRECT:  <SOS> on one side of the river was a bare land . <EOS> 


Example 1 OUTPUT:  many green trees are on two sides of a curved river . <EOS>
Example 1 CORRECT:  <SOS> a playground is next to some green trees and a white building . <EOS> 


Example 1 OUTPUT:  a playground is surrounded by many green trees and many buildings . <EOS>


[Epoch 9 / 20]


                                                      

Loss of the epoch is 1.1086940694023186




Example 1 CORRECT:  <SOS> many buildings are in an industrial area . <EOS> 


Example 1 OUTPUT:  many buildings and some green trees are in an industrial area . <EOS>
Example 1 CORRECT:  <SOS> some pieces of farmlands are together . <EOS> 


Example 1 OUTPUT:  a large number of trees were planted around the lake . <EOS>


[Epoch 10 / 20]


                                                   

Loss of the epoch is 1.0733957379729908




Example 1 CORRECT:  <SOS> many green buildings and trees are ordered in a dense residential area . <EOS> 


Example 1 OUTPUT:  many green buildings and trees are located in a dense residential area . <EOS>
Example 1 CORRECT:  <SOS> many cars are parked in a parking lot near a large building with different green trees . <EOS> 


Example 1 OUTPUT:  many cars are parked in a parking lot near a road . <EOS>


[Epoch 11 / 20]


                                                   

Loss of the epoch is 1.0447862249025157




Example 1 CORRECT:  <SOS> four planes are parked in an airport near several buildings with parking lots . <EOS> 


Example 1 OUTPUT:  many planes are parked near a terminal in an airport . <EOS>
Example 1 CORRECT:  <SOS> many buildings and some green trees are around a playground . <EOS> 


Example 1 OUTPUT:  a playground is semi - surrounded by some green trees and many buildings . <EOS>


[Epoch 12 / 20]


                                                   

Loss of the epoch is 1.0173407623279076




Example 1 CORRECT:  <SOS> the colors of the two ponds are both bright blue . <EOS> 


Example 1 OUTPUT:  a pond is near a river with some green trees . <EOS>
Example 1 CORRECT:  <SOS> a piece of sand in the desert is like fish scale . <EOS> 


Example 1 OUTPUT:  it 's a big piece of yellow desert . <EOS>


[Epoch 13 / 20]


                                                   

Loss of the epoch is 0.9917280465970764




Example 1 CORRECT:  <SOS> a football field is close to several green trees and buildings . <EOS> 


Example 1 OUTPUT:  a playground is surrounded by some green trees and buildings . <EOS>
Example 1 CORRECT:  <SOS> some red buildings are near a church next to a road with many people . <EOS> 


Example 1 OUTPUT:  a church is near a road with many cars running . <EOS>


[Epoch 14 / 20]


                                                   

Loss of the epoch is 0.9744462008487451




Example 1 CORRECT:  <SOS> the viaduct here is majestic and complicated . <EOS> 


Example 1 OUTPUT:  many green trees are near a viaduct . <EOS>
Example 1 CORRECT:  <SOS> the blue <UNK> <UNK> are next to the neighborhood . <EOS> 


Example 1 OUTPUT:  many green trees and some buildings are in a school . <EOS>


[Epoch 15 / 20]


                                                   

Loss of the epoch is 0.9482870778194242




Example 1 CORRECT:  <SOS> some boats are in a port near a large piece of green lawn . <EOS> 


Example 1 OUTPUT:  some boats are in a port near green plants and many buildings . <EOS>
Example 1 CORRECT:  <SOS> many green trees and a small pond are in a park near a road . <EOS> 


Example 1 OUTPUT:  many green trees and a pond are in a park near a road . <EOS>


[Epoch 16 / 20]


                                                   

Loss of the epoch is 0.9289245242150923




Example 1 CORRECT:  <SOS> the light green trees are on the side of the road . <EOS> 


Example 1 OUTPUT:  many cars are parked in a parking lot near several buildings . <EOS>
Example 1 CORRECT:  <SOS> here lies an esthetic square with large meadows surrounded by roads . <EOS> 


Example 1 OUTPUT:  a square and some green trees are around a circle square . <EOS>


[Epoch 17 / 20]


                                                   

Loss of the epoch is 0.9087784592534455




Example 1 CORRECT:  <SOS> it 's a big piece of mountain . <EOS> 


Example 1 OUTPUT:  it is a piece of yellow mountains . <EOS>
Example 1 CORRECT:  <SOS> many green trees are in two sides of a curved river . <EOS> 


Example 1 OUTPUT:  many green trees are found on two sides of a curved river . <EOS>


[Epoch 18 / 20]


                                                   

Loss of the epoch is 0.8935033209353732




Example 1 CORRECT:  <SOS> many sands form a piece of desert . <EOS> 


Example 1 OUTPUT:  it 's a big piece of yellow desert . <EOS>
Example 1 CORRECT:  <SOS> on one side of the river are rows of blue roofed houses . <EOS> 


Example 1 OUTPUT:  many green trees and some buildings are in a park with a pond . <EOS>


[Epoch 19 / 20]


                                                   

Loss of the epoch is 0.8830576661233798




Example 1 CORRECT:  <SOS> the railway station . <EOS> 


Example 1 OUTPUT:  many buildings and some green trees are located on both sides of a train station . <EOS>
Example 1 CORRECT:  <SOS> several large buildings and some green trees are located in a commercial area . <EOS> 


Example 1 OUTPUT:  many buildings and some green trees are in a commercial area . <EOS>




In [2]:
# NOTE(review): hard-coded literal, not computed from the training run. It matches the final
# epoch loss printed above (0.8831) rounded to two decimals. However, this cell's execution
# count (2) is lower than the training cell's (3), so confirm which run it refers to.
0.88

0.88