#### Parameters

In [1]:
import os, glob
import torch
import numpy as np
# Spatial scale factor; Encoder multiplies IMAGE_HEIGHT / IMAGE_WIDTH by it
# when sizing its flattened feature dimension.
scale_factor = 2 ** 1
# Target image size in pixels used by the dataset and the text/visual encoders.
IMAGE_HEIGHT = 64
IMAGE_WIDTH =  150
# Use the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# device = "cpu"  # uncomment to force CPU execution
batch_size = 2
num_example = 2
# Width of each character embedding vector.
embedding_size = 64
# NOTE(review): presumably the longest label length in the data set, with the
# extra 4 positions reserved for special tokens - TODO confirm.
Max_str = 81
text_max_len = Max_str + 4
# Character inventory accepted in text labels: space, a subset of ASCII
# punctuation, the ten digits and the upper/lower-case Latin letters.
vocab = set(
    " !\"#&'()*+,-./"
    "0123456789"
    ":;?"
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    "abcdefghijklmnopqrstuvwxyz"
)

# Layer layout "E": integers are channel counts, "M" entries separate the
# groups (19 entries in total).
cfg = {
    "E": (
        [64] * 2
        + [128] * 2
        + ["M"]
        + [256] * 4
        + ["M"]
        + [512] * 4
        + ["M"]
        + [512] * 4
    ),
}
# Character <-> index lookup tables.
# `vocab` is a set and string hashing is randomised per interpreter process,
# so enumerating it directly yields a different numbering after every kernel
# restart (saved weights / encoded labels would stop matching). Sorting pins
# the numbering.
_ordered_vocab = sorted(vocab)
encoder = {char: index for index, char in enumerate(_ordered_vocab)}
decoder = {index: char for index, char in enumerate(_ordered_vocab)}
# Example with a two-character vocabulary:
#   encoder = {"A": 0, "B": 1}
#   decoder = {0: "A", 1: "B"}

# NOTE(review): the ids 0-2 below are also produced by `encoder` for the first
# three characters; confirm whether the special tokens need their own id range.
tokens = {"GO_TOKEN": 0, "END_TOKEN": 1, "PAD_TOKEN": 2}
NUM_WRITERS = 500


### Helper functions

In [2]:

def pad_str(data, max_len=None):
    """Right-pad every string in ``data`` with spaces up to a fixed length.

    Args:
        data: Iterable of strings (for example one batch of text labels).
        max_len: Target length. Defaults to the module-level ``text_max_len``.

    Returns:
        tuple: The strings in input order. Entries that are already at least
        ``max_len`` characters long are returned unchanged (never truncated).
    """
    if max_len is None:
        max_len = text_max_len
    return tuple(text.ljust(max_len, " ") for text in data)


def encoding(label, decoder):
    """Convert nested text labels into index tensors.

    Args:
        label: Sequence of groups, each group being a sequence of words
            (indexed as label[example][batch]).
        decoder: Mapping from a character to its integer index (despite the
            parameter name, it is used as ``decoder[char]``).

    Returns:
        list: One tensor per group; row ``j`` holds the indices of the
        characters of the ``j``-th word. All words of a group must have the
        same length for the tensor to be built.
    """
    encoded_groups = []
    for group in label:
        rows = []
        for word in group:
            rows.append([decoder[char] for char in word])
        encoded_groups.append(torch.tensor(rows))
    return encoded_groups  # [examples][batch_size]


#### TextEncoder_FC

### Visual encoder

In [3]:
import torch.nn as nn
class Visual_encoder(nn.Module):
    """Stack of 3x3 size-preserving convolutions followed by a 2x upsample.

    Only the first stage (``conv1``) contains ReLU / BatchNorm layers;
    ``conv2`` .. ``conv5`` are applied back to back without a non-linearity.
    The output has 32 channels and twice the input height and width.
    """

    @staticmethod
    def _same_conv(in_channels, out_channels):
        """Build a 3x3, stride-1, padding-1 convolution (keeps H and W)."""
        return nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def __init__(self):
        super(Visual_encoder, self).__init__()
        self.conv1 = nn.Sequential(
            self._same_conv(1, 100),
            nn.ReLU(),
            nn.BatchNorm2d(100),
            self._same_conv(100, 100),
            nn.ReLU(),
        )
        self.conv2 = self._same_conv(100, 32)
        self.conv3 = self._same_conv(32, 64)
        self.conv4 = self._same_conv(64, 128)
        self.conv5 = self._same_conv(128, 32)
        self.upsample1 = nn.Upsample(scale_factor=2, mode="nearest")

    def forward(self, x):
        """Encode ``x`` into a 32-channel feature map of doubled resolution.

        The first two axes of ``x`` are swapped before ``conv1``; since
        ``conv1`` takes a single input channel, ``x.size(0)`` has to be 1.
        """
        print("Shape of the Input in VGG network:-", x.shape)
        features = self.conv1(x.permute(1, 0, 2, 3))
        for conv in (self.conv2, self.conv3, self.conv4, self.conv5):
            features = conv(features)
        return self.upsample1(features)


            


### Resnet Functions

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F



class AdaLN(nn.Module):
    """Per-sample, per-channel normalisation over the spatial axes.

    Each (sample, channel) plane is standardised with its own mean and
    variance taken over dims 2 and 3, then scaled by ``gamma`` and shifted by
    ``beta``. ``rho`` is registered and initialised but is not read by
    ``forward``.
    """

    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.eps = eps
        param_shape = (1, num_features, 1, 1)
        self.rho = nn.Parameter(torch.empty(param_shape))
        self.gamma = nn.Parameter(torch.empty(param_shape))
        self.beta = nn.Parameter(torch.empty(param_shape))
        self.reset_parameters()

    def reset_parameters(self):
        """Reset rho to 0.9, gamma to 1.0 and beta to 0.0."""
        initial_values = ((self.rho, 0.9), (self.gamma, 1.0), (self.beta, 0.0))
        for parameter, value in initial_values:
            nn.init.constant_(parameter, value)

    def forward(self, x):
        """Normalise a 4-D tensor over its last two axes and apply the affine."""
        spatial_axes = [2, 3]
        mean = x.mean(dim=spatial_axes, keepdim=True)
        var = x.var(dim=spatial_axes, keepdim=True)
        normalised = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma * normalised + self.beta


class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with a residual connection.

    Pipeline: conv1 -> norm -> activation -> AdaLN -> conv2 -> norm ->
    (+ shortcut) -> activation.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by both convolutions.
        stride: Stride of the first convolution.
        norm_layer: Callable building a normalisation layer from a channel
            count.
        activation: Element-wise activation function.

    When ``in_channels == out_channels`` and ``stride == 1`` the shortcut is
    the identity and adds no parameters. Otherwise the input cannot be added
    to the main branch as-is (channel count or spatial size differ), so it is
    projected with a strided 1x1 convolution followed by ``norm_layer``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        stride=1,
        norm_layer=nn.BatchNorm2d,
        activation=F.relu,
    ):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.bn1 = norm_layer(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = norm_layer(out_channels)
        self.stride = stride
        self.activation = activation
        self.adaln = AdaLN(out_channels)

        if stride != 1 or in_channels != out_channels:
            # A 1x1 conv with the same stride yields exactly the spatial size
            # produced by the 3x3 / padding-1 main branch.
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                norm_layer(out_channels),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        """Apply the block to an (N, in_channels, H, W) tensor."""
        residual = self.shortcut(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.activation(out)
        out = self.adaln(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = out + residual
        out = self.activation(out)
        return out


class DisModel(nn.Module):
    """Convolutional discriminator producing ``final_size`` logits per image.

    A 7x7 stem lifts the single input channel to 16 channels, six residual
    blocks process the 16-channel map, and a final 7x7 convolution produces
    ``final_size`` channels which are averaged over the spatial axes.
    """

    def __init__(self):
        super(DisModel, self).__init__()
        # define the number of layers
        self.n_layers = 6
        self.final_size = 1024
        in_dim = 1
        out_dim = 16
        self.ff_cc = nn.Conv2d(
            in_channels=in_dim,
            out_channels=out_dim,
            kernel_size=7,
            stride=1,
            padding="same",
        )
        # The residual stack runs on the stem output, i.e. on `out_dim`
        # channels (building it with `in_dim` made the first block reject its
        # input).
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(out_dim, out_dim) for _ in range(self.n_layers)]
        )
        self.cnn_f = nn.Conv2d(
            out_dim, self.final_size, kernel_size=7, stride=1, padding="same"
        )
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, x):
        """Return logits of shape (batch, final_size) for images ``x``.

        All convolutions keep the spatial size, so the map is averaged over
        H and W. For a 1x1 map this equals squeezing the two trailing axes.
        """
        features = self.res_blocks(self.ff_cc(x))
        logits = self.cnn_f(features)
        return logits.mean(dim=(-2, -1))

    def calc_dis_fake_loss(self, input_fake):
        """Discriminator loss on generated images (target: all zeros)."""
        resp_fake = self.forward(input_fake)
        label = torch.zeros_like(resp_fake)
        fake_loss = self.bce(resp_fake, label)
        return fake_loss

    def calc_dis_real_loss(self, input_real):
        """Discriminator loss on real images (target: all ones)."""
        resp_real = self.forward(input_real)
        label = torch.ones_like(resp_real)
        real_loss = self.bce(resp_real, label)
        return real_loss

    def calc_gen_loss(self, input_fake):
        """Generator loss: generated images should be scored as real (ones)."""
        resp_fake = self.forward(input_fake)
        label = torch.ones_like(resp_fake)
        fake_loss = self.bce(resp_fake, label)
        return fake_loss


class WriterClaModel(nn.Module):
    """Writer classifier that returns the cross-entropy loss directly.

    Args:
        num_writers: Number of writer classes (output channels of the head).
    """

    def __init__(self, num_writers) -> None:
        super(WriterClaModel, self).__init__()
        self.n_layers = 6
        in_dim = 1
        out_dim = 16
        # 7x7 kernel with padding 1: each spatial axis shrinks by 4.
        self.cnn_f = nn.Conv2d(
            in_channels=in_dim,
            out_channels=out_dim,
            kernel_size=7,
            stride=1,
            padding=1,
            padding_mode="reflect",
        )
        # The residual stack runs on the stem output, i.e. on `out_dim`
        # channels (building it with `in_dim` made the first block reject its
        # input).
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(out_dim, out_dim) for _ in range(self.n_layers)]
        )
        # Second 7x7 / padding-1 convolution: another 4 pixels off each axis.
        self.ff_cc = nn.Conv2d(
            in_channels=out_dim,
            out_channels=num_writers,
            kernel_size=7,
            stride=1,
            padding=1,
        )
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, x, y):
        """Return the cross-entropy loss of images ``x`` against targets ``y``.

        The class map is averaged over its spatial axes so the logits have
        shape (batch, num_writers), which is what ``CrossEntropyLoss`` expects
        for per-image class-index targets. For a 1x1 map this equals squeezing
        the two trailing axes.
        """
        features = self.res_blocks(self.cnn_f(x))
        logits = self.ff_cc(features).mean(dim=(-2, -1))
        loss = self.cross_entropy(logits, y)
        return loss


class GenModel_FC(nn.Module):
    """Bundle of the visual encoder, text encoder and decoder sub-models."""

    def __init__(self):
        super(GenModel_FC, self).__init__()
        self.enc_image = Visual_encoder().to(device)
        self.enc_text = TextEncoder_FC().to(device)
        self.dec = Decorder().to(device)
        # Fuses 1024 concatenated channels down to 512.
        self.linear_mix = nn.Linear(1024, 512)

    def decode(self, content, label_text):
        """Run the decoder on ``content`` / ``label_text`` and return its output.

        The result used to be computed and dropped, so callers received None.
        """
        return self.dec(content, label_text)

    def mix(self, feat_xs, feat_embed):
        """Concatenate two feature maps on the channel axis and project them.

        The inputs are joined along dim 1; the joined channel count has to be
        1024 to match ``linear_mix``. The linear layer is applied with the
        channels moved to the last axis, and the result is returned in
        channel-first layout with 512 channels.
        """
        feat_mix = torch.cat([feat_xs, feat_embed], dim=1)
        channels_last = feat_mix.permute(0, 2, 3, 1)
        mixed = self.linear_mix(channels_last)
        return mixed.permute(0, 3, 1, 2)


class Generator(nn.Module):
    """Residual conv network with two 2x transposed-conv upsampling stages.

    Args:
        class_num: Number of output channels produced by the final linear map.
        num_res_blocks: Number of 64-channel residual blocks.
        norm_layer: Callable building a normalisation layer from a channel
            count.
        activation: Activation handed to the residual blocks.

    ``self.tanh`` is registered but not applied in ``forward``; the activation
    is only used inside the residual blocks.
    """

    def __init__(
        self, class_num, num_res_blocks=4, norm_layer=AdaLN, activation=F.leaky_relu
    ):
        super().__init__()
        self.num_res_blocks = num_res_blocks
        self.norm_layer = norm_layer
        self.activation = activation

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=1, padding=3, bias=False)
        self.bn1 = norm_layer(64)
        self.res_blocks = nn.Sequential(
            *[
                ResidualBlock(
                    64, 64, norm_layer=self.norm_layer, activation=self.activation
                )
                for _ in range(self.num_res_blocks)
            ]
        )
        self.conv2 = nn.ConvTranspose2d(
            64, 32, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False
        )
        self.bn2 = norm_layer(32)
        self.conv3 = nn.ConvTranspose2d(
            32, 16, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False
        )
        self.bn3 = norm_layer(16)
        self.conv4 = nn.ConvTranspose2d(
            16, 3, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.linear = nn.Linear(3, out_features=class_num)
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Map (N, 3, H, W) input to an (N, class_num, 4H, 4W) output.

        ``self.linear`` has 3 input features, matching the 3 channels produced
        by ``conv4``. ``nn.Linear`` acts on the last axis, so the channels are
        moved there for the projection and moved back afterwards (applying it
        to the width axis could never match 3 features).
        NOTE(review): assumes the linear layer is meant as a per-pixel channel
        projection - confirm against the intended architecture.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.res_blocks(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.conv4(x)
        x = self.linear(x.permute(0, 2, 3, 1))
        return x.permute(0, 3, 1, 2)


class RecModel(nn.Module):
    """Chains ``Encoder`` and ``Decorder``: image features, then text decoding."""

    def __init__(self):
        super(RecModel, self).__init__()
        self.enc = Encoder()
        self.dec = Decorder()

    def forward(self, image, text):
        """Encode ``image`` and decode the result together with ``text``."""
        encoded_image = self.enc(image)
        return self.dec(encoded_image, text)


### Attention

In [5]:
import torch
from torch import nn


class Head(nn.Module):
    """Single self-attention head operating on a 2-D input.

    Queries, keys and values are bias-free linear projections of the same
    input. Because the input is 2-D (rows x features), the attention weights
    form a rows x rows matrix. ``proj`` is registered but not used in
    ``forward``.
    """

    def __init__(self, infeature, out_feature):
        super().__init__()

        # Projections for queries, keys and values.
        self.wq = nn.Linear(infeature, out_feature, bias=False)
        self.wk = nn.Linear(infeature, out_feature, bias=False)
        self.wv = nn.Linear(infeature, out_feature, bias=False)
        # Scores are scaled by 1 / sqrt(infeature).
        self.scale = 1.0 / (infeature ** 0.5)

        # Output projection (currently unused).
        self.proj = nn.Linear(infeature, out_feature, bias=False)

    def forward(self, x):
        """Attend over the rows of ``x`` (shape: rows x infeature).

        Returns a (rows, out_feature) tensor. Raises ValueError when ``x`` is
        not 2-D, because its shape is unpacked into exactly two values.
        """
        _rows, _features = x.shape

        queries = self.wq(x)
        keys = self.wk(x)
        values = self.wv(x)

        scores = torch.matmul(queries, keys.transpose(-2, -1)) * self.scale
        attention = torch.softmax(scores, dim=-1)
        return torch.matmul(attention, values)


class MultiHeadAttention(nn.Module):
    """Several ``Head`` modules applied to the same input.

    The head outputs are concatenated with ``torch.cat``'s default axis
    (dim 0), so the result has ``num_heads`` times as many rows as the input.
    """

    def __init__(self, infeature, out_feature, num_heads, dropout):
        super().__init__()
        head_modules = []
        for _ in range(num_heads):
            head_modules.append(Head(infeature=infeature, out_feature=out_feature))
        self.heads = nn.ModuleList(head_modules)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run every head on ``x``, stack the results on dim 0, apply dropout."""
        per_head = [head(x) for head in self.heads]
        out = self.dropout(torch.cat(per_head))
        print(out.shape)
        return out


class Cross_attention(nn.Module):
    """Single cross-attention head.

    Queries are projected from the ``decoder`` input; keys and values are both
    projected from the ``encoder`` input. ``proj`` is registered but not used
    in ``forward``.
    """

    def __init__(self, infeature, out_feature):
        super().__init__()

        # Projections for queries, keys and values.
        self.wq = nn.Linear(infeature, out_feature, bias=False)
        self.wk = nn.Linear(infeature, out_feature, bias=False)
        self.wv = nn.Linear(infeature, out_feature, bias=False)
        # Scores are scaled by 1 / sqrt(infeature).
        self.scale = 1.0 / (infeature ** 0.5)

        # Output projection (currently unused).
        self.proj = nn.Linear(infeature, out_feature, bias=False)

    def forward(self, decoder, encoder):
        """Attend from ``decoder`` rows to ``encoder`` rows.

        Args:
            decoder: (..., n_query, infeature) tensor providing the queries.
            encoder: (..., n_key, infeature) tensor providing keys and values.

        Returns:
            (..., n_query, out_feature) tensor.
        """
        q = self.wq(decoder)
        k = self.wk(encoder)
        # Values must come from the same source as the keys: the attention
        # matrix has one column per encoder row, so it can only weight
        # encoder-side values. Projecting them from `decoder` only happened to
        # work when both inputs had the same number of rows.
        v = self.wv(encoder)
        weights = torch.matmul(q, k.transpose(-2, -1))
        weights = weights * self.scale
        weights = nn.functional.softmax(weights, dim=-1)

        # Apply attention weights to values
        output = torch.matmul(weights, v)
        return output


class MultiHead_CrossAttention(nn.Module):
    """Several ``Cross_attention`` heads applied to the same pair of inputs.

    The head outputs are concatenated with ``torch.cat``'s default axis
    (dim 0).
    """

    def __init__(self, infeature, out_feature, num_heads, dropout):
        super().__init__()
        head_modules = []
        for _ in range(num_heads):
            head_modules.append(
                Cross_attention(infeature=infeature, out_feature=out_feature)
            )
        self.heads = nn.ModuleList(head_modules)
        self.dropout = nn.Dropout(dropout)

    def forward(self, encoder, decoder):
        """Run every head, stack the results on dim 0 and apply dropout.

        The arguments are forwarded positionally, so the first argument
        (named ``encoder`` here) becomes the first parameter of
        ``Cross_attention.forward`` and the second argument its second.
        """
        per_head = []
        for head in self.heads:
            per_head.append(head.forward(encoder, decoder))
        out = self.dropout(torch.cat(per_head))
        print(out.shape)
        return out


### Decoder

In [6]:


class Decorder(torch.nn.Module):
    """Text-conditioned decoder with self- and cross-attention stages.

    Args:
        in_feature: Accepted for interface compatibility; the working input
            width is always ``embedding_size * IMAGE_HEIGHT * IMAGE_WIDTH``.
        out_feature: Width of the attention / linear blocks.
        dropout: Dropout probability applied after cross-attention.
    """

    def __init__(self, in_feature=32, out_feature=128, dropout=0.3):
        super().__init__()
        self.dropout = dropout
        self.out_feature = out_feature
        self.TextStyle = TextEncoder_FC().to(device)
        # Flattened size of the text-style map produced by TextStyle.
        self.in_feature = embedding_size * IMAGE_HEIGHT * IMAGE_WIDTH
        # NOTE(review): this layer maps embedding_size * text_max_len inputs to
        # embedding_size * IMAGE_HEIGHT * IMAGE_WIDTH outputs, which is a very
        # large weight matrix with the current constants - confirm it fits in
        # memory.
        self.linear_upsampling = nn.Linear(
            embedding_size * text_max_len, self.in_feature
        )
        self.linear_downsampling = nn.Linear(
            in_features=self.in_feature, out_features=self.out_feature
        )
        self.block_with_attention = LayerNormLinearDropoutBlock(
            in_features=self.in_feature,
            out_features=self.out_feature,
            num_heads=2,
            dropout_prob=0.2,
            attention=True,
        )
        self.block_without_attention = LayerNormLinearDropoutBlock(
            in_features=self.in_feature,
            out_features=self.out_feature,
            num_heads=2,
            dropout_prob=0.2,
            attention=False,
        )
        self.norm = nn.LayerNorm(self.out_feature)
        self.cross_attention = MultiHead_CrossAttention(
            infeature=self.out_feature,
            out_feature=self.out_feature,
            num_heads=2,
            dropout=0.2,
        )
        self.drop = nn.Dropout(self.dropout)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, encoder_out, text_style_content=None):
        """Decode ``encoder_out`` conditioned on ``text_style_content``.

        Args:
            encoder_out: Encoder features handed to the cross-attention stage.
            text_style_content: Index tensor passed to ``self.TextStyle``.

        Returns:
            The final combined features (no softmax is applied).
        """
        char_embedding, global_net = self.TextStyle(text_style_content)
        char_upsampling = self.linear_upsampling(char_embedding)
        txt_style = global_net + char_upsampling.view(
            global_net.size(0),
            global_net.size(1),
            global_net.size(2),
            global_net.size(3),
        )
        print(f"{txt_style.shape=}")
        attetion_block, layer_norm = self.block_with_attention(
            txt_style.reshape(txt_style.size(0), -1)
        )
        norm_down_sample = self.linear_downsampling(layer_norm)
        # The attention block stacks its heads along dim 0, so the skip branch
        # is tiled to the same row count. The factor is derived from the
        # tensors themselves so it holds for any batch size, not only the
        # global `batch_size`.
        norm_down_sample = norm_down_sample.repeat(
            attetion_block.size(0) // norm_down_sample.size(0), 1
        )
        attention_norm = attetion_block + norm_down_sample
        block_without_attention, _ = self.block_without_attention(attention_norm)
        combained_without_attention = block_without_attention + attention_norm
        norm = self.norm(combained_without_attention)
        cross_attention = self.cross_attention(norm, encoder_out)
        drop_out = self.drop(cross_attention)
        # Cross-attention stacks its heads along dim 0 as well.
        norm = norm.repeat(drop_out.size(0) // norm.size(0), 1)
        combained_without_attention = drop_out + norm
        block_without_attention2, _ = self.block_without_attention(
            combained_without_attention
        )
        final_combained = block_without_attention2 + combained_without_attention
        return final_combained

class LayerNormLinearDropoutBlock(nn.Module):
    """LayerNorm followed by either multi-head attention or a linear layer.

    With ``attention=True`` the norm runs over ``in_features`` and feeds a
    ``MultiHeadAttention``. With ``attention=False`` both the norm and the
    linear layer work on ``out_features`` (``in_features`` and ``num_heads``
    are not used). Dropout is applied to the transformed tensor in both modes.
    """

    def __init__(
        self, in_features, out_features, num_heads, dropout_prob=0.1, attention=False
    ):
        super(LayerNormLinearDropoutBlock, self).__init__()
        self.attention = attention
        if not self.attention:
            print("attention is not applied")
            self.layer_norm = nn.LayerNorm(out_features)
            self.linear = nn.Linear(out_features, out_features)
        else:
            self.layer_norm = nn.LayerNorm(in_features)
            self.atten = MultiHeadAttention(
                in_features, out_features, num_heads, dropout_prob
            )
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        """Return ``(transformed, normalised)`` for input ``x``.

        ``normalised`` is the LayerNorm output; ``transformed`` is the
        attention / linear result after dropout.
        """
        layer_norm = self.layer_norm(x)
        if not self.attention:
            print("attention is not applied")
            transformed = self.linear(layer_norm)
        else:
            transformed = self.atten(layer_norm)
        return self.dropout(transformed), layer_norm


### Encoder

In [7]:
from torch import nn

# from torch.autograd import Variable
import numpy as np
import torch.functional as F

# from models.vgg_tro_channel1 import vgg16_bn


class Encoder(nn.Module):
    """Image encoder: sum of two conv branches, then attention / linear blocks.

    NOTE(review): ``Generator_Resnet`` is not defined in the visible part of
    this notebook - confirm where it comes from before instantiating.
    """

    def __init__(self):
        super(Encoder, self).__init__()
        B, C, H, W = (
            batch_size,
            32,
            IMAGE_HEIGHT * scale_factor,
            IMAGE_WIDTH * scale_factor,
        )
        self.num_heads = 4
        print(f"channel:-{C=},Hight:- {H=} width:- {W=} Batch:- {B=}")
        # Flattened size of a 32-channel map at the scaled image resolution.
        self.in_feature = 32 * IMAGE_HEIGHT * scale_factor * IMAGE_WIDTH * scale_factor
        self.out_feature = 128
        self.resnet = Generator_Resnet(class_num=2, num_res_blocks=2).to(device)
        self.visual_encoder = Visual_encoder().to(device)
        self.linear_downsampling = nn.Linear(
            in_features=self.in_feature, out_features=self.out_feature
        )
        self.block_with_attention = LayerNormLinearDropoutBlock(
            in_features=self.in_feature,
            out_features=self.out_feature,
            num_heads=2,
            dropout_prob=0.2,
            attention=True,
        )
        self.block_without_attention = LayerNormLinearDropoutBlock(
            in_features=self.in_feature,
            out_features=self.out_feature,
            num_heads=2,
            dropout_prob=0.2,
            attention=False,
        )
        self.norm = nn.LayerNorm(self.out_feature)

    def forward(self, x):
        """Encode ``x`` and return layer-normalised features.

        ``x`` is handed to the resnet branch with its first two axes swapped
        and to the visual encoder unchanged (that module swaps the axes
        itself). The two branch outputs are added, flattened per row and sent
        through the attention and linear blocks.
        """
        resent = self.resnet(x.permute(1, 0, 2, 3))
        visual_encder = self.visual_encoder(x)
        print(
            f"Shape of the resent output{resent.shape} and Vgg output shape{visual_encder.shape}"
        )
        combained_out = resent + visual_encder
        attention_block, norm_layer = self.block_with_attention(
            combained_out.view(combained_out.size(0), -1)
        )
        down_sampled_norm = self.linear_downsampling(norm_layer)
        # The attention block stacks its heads along dim 0, so the skip branch
        # is tiled to the same row count. The factor is derived from the
        # tensors themselves so it holds for any batch size, not only the
        # global `batch_size`.
        down_sampled_norm = down_sampled_norm.repeat(
            attention_block.size(0) // down_sampled_norm.size(0), 1
        )
        combained_attention = down_sampled_norm + attention_block
        without_attention, _ = self.block_without_attention(combained_attention)
        combained_with_attention = combained_attention + without_attention
        final_norm = self.norm(combained_with_attention)
        print("End of encoder")
        return final_norm


In [8]:
import torch 
from torch import nn
class TextEncoder_FC(nn.Module):
    """Embeds label indices and derives a flat code plus a spatial text map.

    ``forward`` returns two tensors:
      * ``out``: the flattened embeddings passed through the ``fc`` stack,
        shape (batch, text_max_len * embedding_size);
      * ``final_res``: a per-character linear re-projection of the embeddings
        laid out along the width axis and repeated ``IMAGE_HEIGHT`` times
        along the height axis.

    ``linear1`` is registered but not used in ``forward``.
    """

    def __init__(self) -> None:
        super(TextEncoder_FC, self).__init__()
        self.embed = nn.Embedding(len(vocab), embedding_size)
        flat_size = text_max_len * embedding_size
        self.fc = nn.Sequential(
            nn.Flatten(),  # flatten the input tensor to a 1D tensor
            nn.Linear(flat_size, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(inplace=False),
            nn.Linear(1024, 2048),
            nn.BatchNorm1d(2048),
            nn.ReLU(inplace=False),
            # Output width is tied to the constants instead of a literal so it
            # keeps matching consumers sized as embedding_size * text_max_len.
            nn.Linear(2048, flat_size),
        )
        self.linear = nn.Linear(
            embedding_size * text_max_len, embedding_size * text_max_len
        )
        self.linear1 = nn.Linear(embedding_size, embedding_size * text_max_len)

    def forward(self, x):
        """Encode an index tensor ``x`` of shape (batch, text_max_len).

        A trailing singleton axis on ``x`` is squeezed away before embedding.
        Returns ``(out, final_res)`` as described on the class.
        """
        embedding = self.embed(x.squeeze(-1))  # b, t, embed

        batch_size = embedding.shape[0]
        flat = embedding.reshape(batch_size, -1)  # b, t*embed
        out = self.fc(flat)

        # Re-project the flattened embeddings and restore the (b, t, embed)
        # layout. The actual batch size is used (it was hard-coded to 2).
        xx_new = self.linear(flat).view(
            embedding.size(0), embedding.size(1), embedding.size(2)
        )

        ts = xx_new.shape[1]
        height_reps = IMAGE_HEIGHT
        width_reps = max(1, IMAGE_WIDTH // ts)
        tensor_list = list()
        for i in range(ts):
            # Repeat character i `width_reps` times along the width axis.
            text = [xx_new[:, i : i + 1]]
            tmp = torch.cat(text * width_reps, dim=1)
            tensor_list.append(tmp)

        padding_reps = IMAGE_WIDTH % ts
        if padding_reps:
            # Fill the remaining width with the embedding of index 2 (the value
            # of tokens["PAD_TOKEN"]). The index lives on the same device as
            # the embeddings rather than on the global `device`.
            pad_index = torch.full(
                (1, 1), 2, dtype=torch.long, device=embedding.device
            )
            embedded_padding_char = self.embed(pad_index)
            padding = embedded_padding_char.repeat(batch_size, padding_reps, 1)
            tensor_list.append(padding)

        res = torch.cat(tensor_list, dim=1)  # b, width, embed
        res = res.permute(0, 2, 1).unsqueeze(2)  # b, embed, 1, width
        final_res = torch.cat([res] * height_reps, dim=2)  # b, embed, height, width
        return out, final_res




In [9]:
batch_size
import json
from cv2 import imread,resize

### Data loading

In [10]:
class CustomImageDataset:
    """Map-style dataset yielding ``(image_tensor, label)`` pairs.

    Args:
        base_path: Directory containing one ``<image stem>.json`` label file
            per image.
        img_dir: List of image paths. When omitted, the paths are globbed from
            ``Line_data/Images/*/*/*`` at construction time (the default used
            to be evaluated once, when the class was defined).
    """

    def __init__(self, base_path="Single_Labels", img_dir=None):
        self.base_path = base_path
        if img_dir is None:
            img_dir = glob.glob("Line_data/Images/*/*/*")
        self.img_dir = img_dir

    def Load_Image_Label(self, image_path):
        """Load one grayscale image and its JSON label.

        The image is inverted, tiled horizontally until it is at least
        ``IMAGE_WIDTH`` wide, cropped to ``IMAGE_HEIGHT`` x ``IMAGE_WIDTH`` and
        resized to exactly that size.

        Returns:
            tuple: ``(image array, label)`` where the label is the parsed JSON.

        Raises:
            FileNotFoundError: If the image cannot be read.
        """
        # Accept both "\\" and "/" separators so the label lookup works on
        # every platform, and strip the extension whatever its length.
        file_name = os.path.basename(image_path.replace("\\", "/"))
        json_path = os.path.join(
            self.base_path, os.path.splitext(file_name)[0] + ".json"
        )
        with open(json_path, "r") as json_file:
            label = json.load(json_file)

        img = imread(image_path, 0)
        if img is None:
            # imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        img = 255 - img
        img_width = img.shape[1]
        n_repeats = int(np.ceil(IMAGE_WIDTH / img_width))
        padded_image = np.concatenate([img] * n_repeats, axis=1)
        padded_image = padded_image[:IMAGE_HEIGHT, :IMAGE_WIDTH]
        resized_img = resize(padded_image, (IMAGE_WIDTH, IMAGE_HEIGHT))
        return (resized_img, label)

    def __len__(self):
        """Number of image paths in the dataset."""
        return len(self.img_dir)

    def __getitem__(self, idx):
        """Return ``(float image tensor on `device`, label)`` for index ``idx``."""
        Image, Labels = self.Load_Image_Label(self.img_dir[idx])
        return torch.tensor(Image, device=device).float(), Labels

In [11]:
from torch.utils.data import DataLoader, random_split
TextDatasetObj = CustomImageDataset()
train_ratio = 0.8
test_ratio = 1 - train_ratio

# Calculate the sizes of train and test sets based on the split ratios
train_size = int(train_ratio * len(TextDatasetObj))
test_size = len(TextDatasetObj) - train_size

# Split the dataset into train and test sets
train_set, test_set = random_split(TextDatasetObj, [train_size, test_size])

# The dataset builds its tensors directly on `device`. Returning CUDA tensors
# from DataLoader worker processes is not supported reliably, so workers are
# only used when the data stays on the CPU.
num_workers = 0 if torch.device(device).type == "cuda" else 5

# Create DataLoader instances for train and test sets
train_loader = DataLoader(train_set, batch_size=batch_size, num_workers=num_workers)
test_loader = DataLoader(test_set, batch_size=batch_size, num_workers=num_workers)
# Warn when either split is empty (the previous bitwise test only fired when
# both were empty).
if len(train_loader) == 0 or len(test_loader) == 0:
    print("Data isn't loaded properly")
else:
    print(f"{len(train_loader)=}     {len(test_loader)=}")

len(train_loader)=5341     1336=


In [None]:
# Iterator over the training loader, used to pull a single batch for inspection.
train=iter(train_loader)

In [None]:
next(train)

In [None]:
# Stand-alone instance of the convolutional visual encoder, moved to `device`.
visual_encoder = Visual_encoder().to(device)  # vgg


In [None]:
visual_encoder