## Training Unpruned Model [Not Our Approach]

In [1]:
import torch
import torch.optim as optim
from get_loader import get_loader
from torchvision import transforms
import New_Pruned_Model
import torch.nn as nn
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from utils import print_examples
from New_Pruned_Model import EncoderCNN

In [2]:
mse_loss = nn.MSELoss()
def match_hidden_states(sub_network_hidden_states, decoder_hidden_states , num_selected_layers = 3):
    # Here we have to select the layers coming with decoder network
    num_total_layers = len(decoder_hidden_states)
    selected_layers_indices = [num_selected_layers - 1]
    for i in range(num_selected_layers - 1):
        selected_layers_indices.append(int((num_total_layers - 1) * (i + 1)/num_selected_layers))
    
    selected_decoder_hidden_states = [decoder_hidden_states[idx] for idx in selected_layers_indices]
    loss = 0 
    for sub_state, dec_state in zip(sub_network_hidden_states, selected_decoder_hidden_states):
        loss += mse_loss(sub_state, dec_state)
    return loss

In [3]:
def train():
    mse_loss_2 = nn.MSELoss()
    transform = transforms.Compose([transforms.Resize((350,350)),transforms.RandomCrop((256,256)),
                                transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    images_path , caption_path = r"D:\ML\Korea\Jishu\Jishu\rsicd\images" , r"D:\ML\Korea\Jishu\Jishu\rsicd\captions.csv"
    pruned_resnet_model_path = r"D:\ML\Korea\Jishu\Jishu\Cnn_Pruning\Pruned_Resnet\fine_tuned_model.pth"
    
    BATCH_SIZE = 32
    data_loader , dataset = get_loader(images_path,caption_path ,transform,batch_size = BATCH_SIZE,num_workers=4)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_epochs = 20
    learning_rate = 3e-4
    trg_vocab_size = len(dataset.vocab)
    embedding_size = 512
    num_heads = 8
    num_decoder_layers = 2
    dropout = 0.10
    pad_idx=dataset.vocab.stoi["<PAD>"]
    save_model = True
    writer =SummaryWriter("runs/loss_plot")
    step = 0
    encoder_regularization_penalty = 0.01
    decoder_regularizartion_penalty = 0.01 
    
    model = New_Pruned_Model.EncodertoDecoder(embeding_size=embedding_size, trg_vocab_size=trg_vocab_size, num_heads=num_heads, num_decoder_layers=4, dropout=dropout).to(device)
    model.load_state_dict(torch.load(r'D:\ML\Korea\Jishu\Jishu\Final_Docs\Complete_Model_Pruning\model.pth', map_location=device))
    model = model.to(device)
    
    # Now we define the pruned model
    pruned_model = New_Pruned_Model.PrunedEncodertoDecoder(embeding_size=embedding_size,
                            trg_vocab_size=trg_vocab_size, num_heads=num_heads,
                            num_decoder_layers=num_decoder_layers,
                            dropout=dropout , pruned_resnet_model_path = pruned_resnet_model_path).to(device)
    
    
    optimizer = optim.Adam(pruned_model.parameters(),lr = learning_rate)
    criterion2 = nn.CrossEntropyLoss(ignore_index=pad_idx)
    
    l = []
    for epoch in range(num_epochs):
        print(f"[Epoch {epoch} / {num_epochs}]")
        model.eval()
        pruned_model.train()
        Total_loss = 0.0
        for idx, (images, captions) in tqdm(enumerate(data_loader), total=len(data_loader), leave=False):
            images = images.to(device)
            captions = captions.to(device)
            with torch.no_grad():
                output , hidden_original_decoder_outputs = model(images, captions[:-1])
                
            pruned_model_outputs , hidden_pruned_decoder_outputs = pruned_model(images , captions[:-1])
            pruned_model_outputs = pruned_model_outputs.reshape(-1 , pruned_model_outputs.shape[2])
            target = output.reshape(-1 , output.shape[2]).to(device)
            
            optimizer.zero_grad()
            # print(pruned_model_outputs.shape , target.shape)
            loss_match = mse_loss_2(pruned_model_outputs , target)
            
            # Compute the L2 Regularization loss of the encoder weights
            l2_reg = pruned_model.encoder.compute_penalty(encoder_regularization_penalty)
            mse_loss = match_hidden_states(hidden_pruned_decoder_outputs,hidden_original_decoder_outputs)
            dec_loss = mse_loss * decoder_regularizartion_penalty
            
            # Now calculate the total loss
            total_loss = loss_match + l2_reg + dec_loss
            lossofepoch = total_loss.item()
            Total_loss += lossofepoch
            total_loss.backward()
            
            torch.nn.utils.clip_grad_norm_(pruned_model.parameters(),max_norm=1)
            optimizer.step()
            writer.add_scalar("Training Loss",Total_loss,global_step=step)
            step+=1
            
        LOSS = Total_loss / len(data_loader)
        l.append(LOSS) 
        print("Loss of the epoch is", Total_loss / len(data_loader))
        torch.save(pruned_model , 'unpruned_model_final.pth')
            

        pruned_model.eval()
        print_examples(pruned_model, device, dataset)

In [4]:
train()

  model.load_state_dict(torch.load(r'D:\ML\Korea\Jishu\Jishu\Final_Docs\Complete_Model_Pruning\model.pth', map_location=device))
  self.pruned_image_encoder = torch.load(pruned_resnet_model_path)


[Epoch 0 / 20]


  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
                                                   

Loss of the epoch is 31.23588681407564




Example 1 CORRECT:  <SOS> the roundabout is surrounded by eight buildings . <EOS> 


Example 1 OUTPUT:  many green trees are in a piece of green meadow . <EOS>
Example 1 CORRECT:  <SOS> many pieces of agricultural land and some scattered buildings are together . <EOS> 


Example 1 OUTPUT:  many green trees are in a piece of yellow bareland . <EOS>


[Epoch 1 / 20]


                                                   

Loss of the epoch is 9.78528596798952




Example 1 CORRECT:  <SOS> a large number of trees were planted around the house . <EOS> 


Example 1 OUTPUT:  many green trees are around a square with a pond . <EOS>
Example 1 CORRECT:  <SOS> it 's a big piece of mountain . <EOS> 


Example 1 OUTPUT:  many green trees are in a forest . <EOS>


[Epoch 2 / 20]


                                                   

Loss of the epoch is 2.7596474512828535




Example 1 CORRECT:  <SOS> four white storage tanks are close to a crossroads on a bare piece of land . <EOS> 


Example 1 OUTPUT:  many green trees are around a square . <EOS>
Example 1 CORRECT:  <SOS> the distictive terminal building embraces the airplanes stopped here . <EOS> 


Example 1 OUTPUT:  many green trees are around a square with a circle . <EOS>


[Epoch 3 / 20]


                                                   

Loss of the epoch is 1.1501898534234514




Example 1 CORRECT:  <SOS> several buildings are close to a playground next to a road . <EOS> 


Example 1 OUTPUT:  many green trees and some buildings are in a resort . <EOS>
Example 1 CORRECT:  <SOS> it is a piece of bareland . <EOS> 


Example 1 OUTPUT:  the desert is very vast . <EOS>


[Epoch 4 / 20]


                                                   

Loss of the epoch is 0.9545846092215913




Example 1 CORRECT:  <SOS> this piece of the forest is green and dense . <EOS> 


Example 1 OUTPUT:  many green trees are in a forest . <EOS>
Example 1 CORRECT:  <SOS> some green trees are close to a central square building . <EOS> 


Example 1 OUTPUT:  many green trees are around a square with a circle . <EOS>


[Epoch 5 / 20]


                                                   

Loss of the epoch is 0.9154164511273165




Example 1 CORRECT:  <SOS> this square sits next to a parking lot crammed with cars . <EOS> 


Example 1 OUTPUT:  many green trees and some buildings are in a resort . <EOS>
Example 1 CORRECT:  <SOS> some people are sparsely in a piece of green meadow . <EOS> 


Example 1 OUTPUT:  many green trees are in a forest . <EOS>


[Epoch 6 / 20]


                                                   

Loss of the epoch is 0.8896345334340335




Example 1 CORRECT:  <SOS> the school is on the roadside . <EOS> 


Example 1 OUTPUT:  many green trees and some buildings are in a resort with a pond . <EOS>
Example 1 CORRECT:  <SOS> many cars are in an irregular parking lot surrounded by some green trees . <EOS> 


Example 1 OUTPUT:  many green trees are in a forest . <EOS>


[Epoch 7 / 20]


                                                   

Loss of the epoch is 0.8644981198187726




Example 1 CORRECT:  <SOS> many trees are planted on both sides of the road . <EOS> 


Example 1 OUTPUT:  many green trees are in a forest . <EOS>
Example 1 CORRECT:  <SOS> many green trees are on both sides of a river with a bridge over it . <EOS> 


Example 1 OUTPUT:  many green trees and some buildings are in a park . <EOS>


[Epoch 8 / 20]


                                                   

Loss of the epoch is 0.8445473491958982




Example 1 CORRECT:  <SOS> this <UNK> is patting the flat yellow beach . <EOS> 


Example 1 OUTPUT:  the sea is very beautiful . <EOS>
Example 1 CORRECT:  <SOS> some buildings and green trees are around a playground and several basketball fields . <EOS> 


Example 1 OUTPUT:  many green trees and some buildings are in a resort with a pond . <EOS>


[Epoch 9 / 20]


                                                   

Loss of the epoch is 0.8283466185109343




Example 1 CORRECT:  <SOS> a squared bareland is surrounded by roads . <EOS> 


Example 1 OUTPUT:  many green trees are in a piece of yellow desert . <EOS>
Example 1 CORRECT:  <SOS> a <UNK> arrangement of the buildings around the sparse vegetation in . <EOS> 


Example 1 OUTPUT:  many buildings and some green trees are in a commercial area . <EOS>


[Epoch 10 / 20]


                                                   

Loss of the epoch is 0.8112595766922678




Example 1 CORRECT:  <SOS> many cars are parked on two sides of a road . <EOS> 


Example 1 OUTPUT:  many green trees and some buildings are in a resort . <EOS>
Example 1 CORRECT:  <SOS> the <UNK> - shaped square contains bare land and trees . <EOS> 


Example 1 OUTPUT:  many green trees are around a square . <EOS>


[Epoch 11 / 20]


                                                   

Loss of the epoch is 0.7991010351266846




Example 1 CORRECT:  <SOS> some boats are scattered in a port near a pier . <EOS> 


Example 1 OUTPUT:  a large number of trees are planted on both sides of the river . <EOS>
Example 1 CORRECT:  <SOS> school buildings <UNK> in size , concentrated in the south . <EOS> 


Example 1 OUTPUT:  many green trees and some buildings are in a resort with a pond . <EOS>


[Epoch 12 / 20]


                                                   

Loss of the epoch is 0.7875554970639039




Example 1 CORRECT:  <SOS> a row of gray roofed houses near a baseball field . <EOS> 


Example 1 OUTPUT:  many green trees are around a square with some green trees . <EOS>
Example 1 CORRECT:  <SOS> some green buildings and trees are near a viaduct with a circle with a large building on it . <EOS> 


Example 1 OUTPUT:  many green trees are around a square with some green trees . <EOS>


[Epoch 13 / 20]


                                                   

Loss of the epoch is 0.7764074427887494




Example 1 CORRECT:  <SOS> it is a piece of khaki bareland . <EOS> 


Example 1 OUTPUT:  it 's a piece of yellow desert . <EOS>
Example 1 CORRECT:  <SOS> some white snow covers part of the irregular mountain caki . <EOS> 


Example 1 OUTPUT:  many green trees are in a forest . <EOS>


[Epoch 14 / 20]


                                                   

Loss of the epoch is 0.7670063949153644




Example 1 CORRECT:  <SOS> many grey white buildings and some green trees are located in a dense residential area . <EOS> 


Example 1 OUTPUT:  many buildings and some green trees are in a dense residential area . <EOS>
Example 1 CORRECT:  <SOS> red buildings are on either side of the white church . <EOS> 


Example 1 OUTPUT:  many green trees are around a square with a pond . <EOS>


[Epoch 15 / 20]


                                                   

Loss of the epoch is 0.7591765388170132




Example 1 CORRECT:  <SOS> some green trees are around a polygonal center building . <EOS> 


Example 1 OUTPUT:  many green trees are around a square with some green trees . <EOS>
Example 1 CORRECT:  <SOS> the green hills are next to each other in an uninterrupted line . <EOS> 


Example 1 OUTPUT:  many green trees are in a forest . <EOS>


[Epoch 16 / 20]


                                                   

Loss of the epoch is 0.7522077395789127




Example 1 CORRECT:  <SOS> a large white building is near a parking lot and a road with some green trees . <EOS> 


Example 1 OUTPUT:  many green trees are around a square with some green trees . <EOS>
Example 1 CORRECT:  <SOS> a multilateral green pond is surrounded by many green trees . <EOS> 


Example 1 OUTPUT:  many green trees are around a square of a curved yellow river . <EOS>


[Epoch 17 / 20]


                                                   

Loss of the epoch is 0.7447813768901735




Example 1 CORRECT:  <SOS> the stretch of turbid saddle shaped river is across the forest . <EOS> 


Example 1 OUTPUT:  many green trees are in a forest . <EOS>
Example 1 CORRECT:  <SOS> the cars were parked neatly on the road . <EOS> 


Example 1 OUTPUT:  many green trees are around a square with a circle center . <EOS>


[Epoch 18 / 20]


                                                   

Loss of the epoch is 0.7379815552342852




Example 1 CORRECT:  <SOS> many buildings and green trees are in a dense residential area . <EOS> 


Example 1 OUTPUT:  many green trees and some buildings are in a resort . <EOS>
Example 1 CORRECT:  <SOS> it 's a big piece of naked dirt . <EOS> 


Example 1 OUTPUT:  the desert is very dry . <EOS>


[Epoch 19 / 20]


                                                   

Loss of the epoch is 0.7317658195006829




Example 1 CORRECT:  <SOS> an island on a lake surrounded by many green trees are in a park . <EOS> 


Example 1 OUTPUT:  many green trees and some buildings are in a resort . <EOS>
Example 1 CORRECT:  <SOS> the water in the sea is very rough . <EOS> 


Example 1 OUTPUT:  a large number of trees are planted around the lake . <EOS>


