# xml -> annotation

In [20]:
import xml.etree.ElementTree as ET

def load_xml_data(file_path: str):
    """
    xml data 불러오기
    """
    try:
        tree = ET.parse(file_path)  # XML 파일을 파싱
        root = tree.getroot()
        return root
    except ET.ParseError as e:
        print(f"XML 파일을 파싱하는 동안 오류가 발생했습니다: {e}")
        return None

In [21]:
DATA_PATH=f"../data"
MODEL_PATH=f"../model"
IMAGE_PATH=f"../images/"

DATA_FEATURE_PATH=f"{DATA_PATH}/processed-feature"
DATA_RAW_PATH=f"{DATA_PATH}/raw"
DATA_TEST_PATH=f"{DATA_PATH}/test"

OSMD="osmd-dataset-v1.0.0"


xml_path = f'{DATA_RAW_PATH}/{OSMD}/Rock-ver/Rock-ver.xml'

In [22]:
import xml.etree.ElementTree as ET

"""
- clef-purcussion 추출
- time-signature 추출
- multiple pitch 추출
    <chord/> <-  얘 있으면 동시에 친 거임
    <unpitched>
        <display-step>A</display-step>
        <display-octave>5</display-octave>
    </unpitched>
- 쉼표 추출
    <note>
        <rest/>
        <duration>48</duration>
        <type>quarter</type>
    </note>

output : ["clef-F4+keySignature-CM+note-E3_eighth.|note-C4_eighth.+note-E3_sixteenth|...", "..."]
"""

# 0. stave 마다 새로운 string 생성: print new-system 일 때마다 
# 1. clef-purcussion 삽입: xml에선 맨 처음에만 나오니까 매번 삽입
# 2. time-signature 있으면 삽입
# 3. note 삽입: pitch, duration
#    rest 삽입: duration
# 4. if 동시에 나온 note일 시, | 으로 구분
# 5. else + 로 연결

# MusicXML 파일을 파싱하여 ElementTree 객체 생성
tree = ET.parse(xml_path)
root = tree.getroot()
divisions_element = root.find(".//divisions")
if divisions_element is not None:
    divisions_value = int(divisions_element.text)
division = divisions_value
# 각 stave에 대한 문자열을 저장할 리스트
stave_strings = []


# 각 measure에 대한 처리 함수
def process_measure(measure):
    stave_string = ""

    for element in measure:
        if element.tag == 'attributes':
            # # clef 처리
            # clef = element.find('clef')
            # if clef is not None:
            #     stave_string += f"clef-{clef.find('sign').text}"

            # time signature 처리
            time = element.find('time')
            if time is not None:
                stave_string += f"timeSignature-{time.find('beats').text}/{time.find('beat-type').text}+"

        elif element.tag == 'note':
            if element.find('rest') is not None:
                stave_string += "rest"
            elif element.find('unpitched') is not None:
                # note 정보 처리
                pitch = element.find('unpitched')
                stave_string += f"note-{pitch.find('display-step').text}{pitch.find('display-octave').text}"
            
            rhythm = element.find('type').text
            # 16th -> sixteenth, 32nd → thirty_second
            if rhythm=="16th":
                rhythm="sixteenth"
            elif rhythm=="32nd":
                rhythm="thirty_second"
            stave_string += f"_{rhythm}+"

            if element.find("grace") is None:
                # 0.25 0.375 0.5 0.75 1.0 1.5 2.0 3.0 4.0 중에서
                # 0.375 0.75 1.5 3 이면 뒤에 . 붙이기
                duration_value = element.find('duration').text
                duration = float(duration_value) / float(division)
                d_ = '.' if duration in [0.375, 0.75, 1.5, 3]  else ""
                stave_string = stave_string[:-1]+f"{d_}+"

            # chord 여부에 따라 | 또는 + 추가
            if element.find('chord') is not None:
                # clef-percussion+note-F4_quarter+(<- 이걸 '|'로) note-A5_quarter+
                # note-G5_eighth+note-C5_eighth+note-G5_eighth+ 에서 끝의 + 빼고 reverse 한 후 
                # replace('+', '|', 1) 1개의 +만 |로 replace 한 후
                # 다시 reverse 후, 뒤에 + 붙여서 완성!
                components = stave_string[:-1][::-1]
                components = components.replace('+', '|', 1)
                stave_string = components[::-1]+'+'
                
    return stave_string

# # measure 태그를 가진 모든 element에 대해 처리
# stave_tmp="clef-percussion+"
# for measure in root.findall('.//measure'):
#     # print 태그 확인하여 new-system이 있는 경우에는 새로운 문자열로 시작
#     if measure.find('print') is not None and measure.find('print').get('new-system') == 'yes':
#         stave_strings.append(stave_tmp[:-1])
#         stave_tmp="clef-percussion+"

#     stave_tmp+=process_measure(measure)+"barline+"

print(stave_strings)

[]


# 학습!!!

In [23]:
import torch
import torch.nn as nn

from timm.models.vision_transformer import VisionTransformer
from timm.models.vision_transformer_hybrid import HybridEmbed
from timm.models.resnetv2 import ResNetV2
from timm.models.layers import StdConv2dSame
from einops import repeat


class CustomVisionTransformer(VisionTransformer):
    def __init__(self, img_size, patch_size=16, *args, **kwargs):
        super(CustomVisionTransformer, self).__init__(
            img_size=img_size, patch_size=patch_size, *args, **kwargs
        )
        self.height, self.width = img_size
        self.patch_size = patch_size

    def forward_features(self, x):
        B, c, h, w = x.shape
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        h, w = h // self.patch_size, w // self.patch_size
        pos_emb_ind = repeat(
            torch.arange(h) * (self.width // self.patch_size - w), "h -> (h w)", w=w
        ) + torch.arange(h * w)
        pos_emb_ind = torch.cat((torch.zeros(1), pos_emb_ind + 1), dim=0).long()
        x += self.pos_embed[:, pos_emb_ind]
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        return x


def get_encoder(args):
    backbone_layers = list(args.backbone_layers)
    backbone = ResNetV2(
        layers=backbone_layers,
        num_classes=0,
        global_pool="",
        in_chans=args.channels,
        preact=False,
        stem_type="same",
        conv_layer=StdConv2dSame,
    )
    min_patch_size = 2 ** (len(backbone_layers) + 1)

    def embed_layer(**x):
        ps = x.pop("patch_size", min_patch_size)
        assert ps % min_patch_size == 0 and ps >= min_patch_size, (
            "patch_size needs to be multiple of %i with current backbone configuration"
            % min_patch_size
        )
        return HybridEmbed(**x, patch_size=ps // min_patch_size, backbone=backbone)

    encoder = CustomVisionTransformer(
        img_size=(args.max_height, args.max_width),
        patch_size=args.patch_size,
        in_chans=args.channels,
        num_classes=0,
        embed_dim=args.encoder_dim,
        depth=args.encoder_depth,
        num_heads=args.encoder_heads,
        embed_layer=embed_layer,
        global_pool="",
    )
    return encoder


In [24]:
from math import ceil

import torch
import torch.nn as nn
import torch.nn.functional as F
from x_transformers.x_transformers import (
    AttentionLayers,
    TokenEmbedding,
    AbsolutePositionalEmbedding,
    Decoder,
)


class ScoreTransformerWrapper(nn.Module):
    def __init__(
        self,
        num_note_tokens,
        num_rhythm_tokens,
        num_pitch_tokens,
        # num_lift_tokens,
        max_seq_len,
        attn_layers,
        emb_dim,
        l2norm_embed=False,
    ):
        super().__init__()
        assert isinstance(
            attn_layers, AttentionLayers
        ), "attention layers must be one of Encoder or Decoder"

        dim = attn_layers.dim
        self.max_seq_len = max_seq_len
        self.l2norm_embed = l2norm_embed
        # self.lift_emb = TokenEmbedding(emb_dim, num_lift_tokens, l2norm_embed = l2norm_embed)
        self.pitch_emb = TokenEmbedding(
            emb_dim, num_pitch_tokens, l2norm_embed=l2norm_embed
        )
        self.rhythm_emb = TokenEmbedding(
            emb_dim, num_rhythm_tokens, l2norm_embed=l2norm_embed
        )
        self.pos_emb = AbsolutePositionalEmbedding(
            emb_dim, max_seq_len, l2norm_embed=l2norm_embed
        )

        self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
        self.attn_layers = attn_layers
        self.norm = nn.LayerNorm(dim)
        self.init_()

        # self.to_logits_lift = nn.Linear(dim, num_lift_tokens)
        self.to_logits_pitch = nn.Linear(dim, num_pitch_tokens)
        self.to_logits_rhythm = nn.Linear(dim, num_rhythm_tokens)
        self.to_logits_note = nn.Linear(dim, num_note_tokens)

    def init_(self):
        if self.l2norm_embed:
            # nn.init.normal_(self.lift_emb.emb.weight, std=1e-5)
            nn.init.normal_(self.pitch_emb.emb.weight, std=1e-5)
            nn.init.normal_(self.rhythm_emb.emb.weight, std=1e-5)
            nn.init.normal_(self.pos_emb.emb.weight, std=1e-5)
            return

        # nn.init.kaiming_normal_(self.lift_emb.emb.weight)
        nn.init.kaiming_normal_(self.pitch_emb.emb.weight)
        nn.init.kaiming_normal_(self.rhythm_emb.emb.weight)

    def forward(self, rhythms, pitchs, mask=None, return_hiddens=True, **kwargs):
        x = (
            self.rhythm_emb(rhythms)
            + self.pitch_emb(pitchs)
            # + self.lift_emb(lifts)
            + self.pos_emb(rhythms)
        )
        x = self.project_emb(x)
        x, hiddens = self.attn_layers(
            x, mask=mask, return_hiddens=return_hiddens, **kwargs
        )
        select_hiddens = hiddens[0][3]

        x = self.norm(x)

        # out_lifts = self.to_logits_lift(x)
        out_pitchs = self.to_logits_pitch(x)
        out_rhythms = self.to_logits_rhythm(x)
        out_notes = self.to_logits_note(x)
        # return out_rhythms, out_pitchs, out_lifts, out_notes, x
        return out_rhythms, out_pitchs, out_notes, x


def top_k(logits, thres=0.9):
    k = ceil((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float("-inf"))
    probs.scatter_(1, ind, val)
    return probs


class ScoreDecoder(nn.Module):
    def __init__(
        self, transoformer, noteindexes, num_rhythmtoken, ignore_index=-100, pad_value=0
    ):
        super().__init__()
        self.pad_value = pad_value
        self.ignore_index = ignore_index

        self.net = transoformer
        self.max_seq_len = transoformer.max_seq_len

        note_mask = torch.zeros(num_rhythmtoken)
        note_mask[noteindexes] = 1
        self.note_mask = nn.Parameter(note_mask)

    @torch.no_grad()
    def generate(
        self,
        start_tokens,
        nonote_tokens,
        seq_len,
        eos_token=None,
        temperature=1.0,
        filter_thres=0.9,
        min_p_pow=2.0,
        min_p_ratio=0.02,
        **kwargs
    ):
        device = start_tokens.device
        was_training = self.net.training
        num_dims = len(start_tokens.shape)

        if num_dims == 1:
            start_tokens = start_tokens[None, :]

        b, t = start_tokens.shape

        self.net.eval()
        out_rhythm = start_tokens
        out_pitch = nonote_tokens
        out_lift = nonote_tokens
        mask = kwargs.pop("mask", None)

        if mask is None:
            mask = torch.full_like(
                out_rhythm, True, dtype=torch.bool, device=out_rhythm.device
            )

        for _ in range(seq_len):
            mask = mask[:, -self.max_seq_len :]
            # x_lift = out_lift[:, -self.max_seq_len :]
            x_pitch = out_pitch[:, -self.max_seq_len :]
            x_rhymthm = out_rhythm[:, -self.max_seq_len :]

            rhythmsp, pitchsp, notesp, _ = self.net(
                x_rhymthm, x_pitch, mask=mask, **kwargs
            )

            # filtered_lift_logits = top_k(liftsp[:, -1, :], thres=filter_thres)
            filtered_pitch_logits = top_k(pitchsp[:, -1, :], thres=filter_thres)
            filtered_rhythm_logits = top_k(rhythmsp[:, -1, :], thres=filter_thres)

            # lift_probs = F.softmax(filtered_lift_logits / temperature, dim=-1)
            pitch_probs = F.softmax(filtered_pitch_logits / temperature, dim=-1)
            rhythm_probs = F.softmax(filtered_rhythm_logits / temperature, dim=-1)

            # lift_sample = torch.multinomial(lift_probs, 1)
            pitch_sample = torch.multinomial(pitch_probs, 1)
            rhythm_sample = torch.multinomial(rhythm_probs, 1)

            # out_lift = torch.cat((out_lift, lift_sample), dim=-1)
            out_pitch = torch.cat((out_pitch, pitch_sample), dim=-1)
            out_rhythm = torch.cat((out_rhythm, rhythm_sample), dim=-1)
            mask = F.pad(mask, (0, 1), value=True)

            if (
                eos_token is not None
                and (torch.cumsum(out_rhythm == eos_token, 1)[:, -1] >= 1).all()
            ):
                break

        # out_lift = out_lift[:, t:]
        out_pitch = out_pitch[:, t:]
        out_rhythm = out_rhythm[:, t:]

        if num_dims == 1:
            out = out.squeeze(0)

        self.net.train(was_training)
        return out_rhythm, out_pitch

    def forward(self, rhythms, pitchs, notes, **kwargs):
        # liftsi = lifts[:, :-1]
        # liftso = lifts[:, 1:]
        pitchsi = pitchs[:, :-1]
        pitchso = pitchs[:, 1:]
        rhythmsi = rhythms[:, :-1]
        rhythmso = rhythms[:, 1:]
        noteso = notes[:, 1:]

        mask = kwargs.get("mask", None)
        if mask is not None and mask.shape[1] == rhythms.shape[1]:
            mask = mask[:, :-1]
            kwargs["mask"] = mask

        rhythmsp, pitchsp, notesp, x = self.net(rhythmsi, pitchsi, **kwargs)

        loss_consist = self.calConsistencyLoss(rhythmsp, pitchsp, notesp)
        loss_rhythm = F.cross_entropy(
            rhythmsp.transpose(1, 2), rhythmso, ignore_index=self.ignore_index
        )
        loss_pitch = F.cross_entropy(
            pitchsp.transpose(1, 2), pitchso, ignore_index=self.ignore_index
        )
        # loss_lift = F.cross_entropy(
        #     liftsp.transpose(1, 2), liftso, ignore_index=self.ignore_index
        # )
        loss_note = F.cross_entropy(
            notesp.transpose(1, 2), noteso, ignore_index=self.ignore_index
        )

        return dict(
            loss_rhythm=loss_rhythm,
            loss_pitch=loss_pitch,
            # loss_lift=loss_lift,
            loss_consist=loss_consist,
            loss_note=loss_note,
        )

    def calConsistencyLoss(self, rhythmsp, pitchsp, notesp, gamma=10):
        notesp_soft = torch.softmax(notesp, dim=2)
        note_flag = notesp_soft[:, :, 1]
        rhythmsp_soft = torch.softmax(rhythmsp, dim=2)
        rhythmsp_note = torch.sum(rhythmsp_soft * self.note_mask, dim=2)

        pitchsp_soft = torch.softmax(pitchsp, dim=2)
        pitchsp_note = torch.sum(pitchsp_soft[:, :, 1:], dim=2)

        # liftsp_soft = torch.softmax(liftsp, dim=2)
        # liftsp_note = torch.sum(liftsp_soft[:, :, 1:], dim=2)

        loss = (
            gamma
            * (
                F.l1_loss(rhythmsp_note, note_flag)
                # + F.l1_loss(note_flag, liftsp_note)
                + F.l1_loss(note_flag, pitchsp_note)
            )
            / 3.0
        )
        return loss


def get_decoder(args):
    return ScoreDecoder(
        ScoreTransformerWrapper(
            num_note_tokens=args.num_note_tokens,
            num_rhythm_tokens=args.num_rhythm_tokens,
            num_pitch_tokens=args.num_pitch_tokens,
            # num_lift_tokens=args.num_lift_tokens,
            max_seq_len=args.max_seq_len,
            emb_dim=args.decoder_dim,
            attn_layers=Decoder(
                dim=args.decoder_dim,
                depth=args.decoder_depth,
                heads=args.decoder_heads,
                **args.decoder_args
            ),
        ),
        pad_value=args.pad_token,
        num_rhythmtoken=args.num_rhythmtoken,
        noteindexes=args.noteindexes,
    )


In [25]:
import torch
import torch.nn as nn

class TrOMR(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.encoder = get_encoder(args)
        self.decoder = get_decoder(args)
        self.args = args

    def forward(self, inputs, rhythms_seq, pitchs_seq, note_seq, mask, **kwargs):
        encoded = self.encoder(inputs)
        loss = self.decoder(
            rhythms_seq, pitchs_seq, note_seq, context=encoded, mask=mask, **kwargs
        )
        return loss

    @torch.no_grad()
    def generate(self, x: torch.Tensor, temperature: float = 0.25):
        start_token = (torch.LongTensor([self.args.bos_token] * len(x))[:, None]).to(
            x.device
        )
        nonote_token = (
            torch.LongTensor([self.args.nonote_token] * len(x))[:, None]
        ).to(x.device)

        out_pitch, out_rhythm = self.decoder.generate(
            start_token,
            nonote_token,
            self.args.max_seq_len,
            eos_token=self.args.eos_token,
            context=self.encoder(x),
            temperature=temperature,
        )

        return out_pitch, out_rhythm


In [26]:
import os

import cv2
import torch
import numpy as np
import albumentations as alb
from albumentations.pytorch import ToTensorV2

from transformers import PreTrainedTokenizerFast
from einops import rearrange, reduce, repeat

# from model import TrOMR


class StaffToScore(object):
    def __init__(self, args):
        self.args = args
        self.size_h = args.max_height
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = TrOMR(args)
        # self.model.load_state_dict(torch.load(args.filepaths.checkpoint), strict=True)
        self.model.to(self.device)

        # self.lifttokenizer = PreTrainedTokenizerFast(
        #     tokenizer_file=args.filepaths.lifttokenizer
        # )
        self.pitchtokenizer = PreTrainedTokenizerFast(
            tokenizer_file=args.filepaths.pitchtokenizer
        )
        self.rhythmtokenizer = PreTrainedTokenizerFast(
            tokenizer_file=args.filepaths.rhythmtokenizer
        )
        self.transform = alb.Compose(
            [
                alb.ToGray(always_apply=True),
                alb.Normalize((0.7931, 0.7931, 0.7931), (0.1738, 0.1738, 0.1738)),
                ToTensorV2(),
            ]
        )

    def readimg(self, path):
        # img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        img = cv2.imread(path)
        print(f"1 -- resize 전")
        print(img.shape)

        if img.shape[-1] == 4:
            img = 255 - img[:, :, 3]
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif img.shape[-1] == 3:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        else:
            raise RuntimeError("Unsupport image type!")

        h, w, c = img.shape
        new_h = self.size_h
        new_w = int(self.size_h / h * w)
        new_w = new_w // self.args.patch_size * self.args.patch_size
        img = cv2.resize(img, (new_w, new_h))
        img = self.transform(image=img)["image"][:1]

        print(f"2 -- resize 후")
        print(img.shape)
        print(img.dtype)
        return img

    def preprocessing(self, rgb):
        patches = rearrange(
            rgb,
            "b c (h s1) (w s2) -> b (h w) (s1 s2 c)",
            s1=self.args.patch_size,
            s2=self.args.patch_size,
        )
        return patches

    def train_img2token(self, x, y):
        if not isinstance(x, list):
            x = [x]
        imgs = [self.preprocessing(item) for item in x]
        imgs = torch.cat(imgs).float().unsqueeze(1)


    def detokenize(self, tokens, tokenizer):
        toks = [tokenizer.convert_ids_to_tokens(tok) for tok in tokens]
        for b in range(len(toks)):
            for i in reversed(range(len(toks[b]))):
                if toks[b][i] is None:
                    toks[b][i] = ""
                toks[b][i] = toks[b][i].replace("Ġ", " ").strip()
                if toks[b][i] in (["[BOS]", "[EOS]", "[PAD]"]):
                    del toks[b][i]
        return toks
    def entokenize(self, tokens, tokenizer):
        toks = [tokenizer.convert_tokens_to_ids(tokenizer.tokenize(tok)) for tok in tokens]
        return toks

    def all_entokenize(self, rhythm_y_list, pitch_y_list):
        token_rhythm = self.entokenize(rhythm_y_list, self.rhythmtokenizer)
        token_pitch = self.entokenize(pitch_y_list, self.pitchtokenizer)
        return token_rhythm, token_pitch
    

    def train_model(self, inputs, rhythms_seq, pitchs_seq):
        input = inputs.to(self.device)
        
        return self.model.forward(input, rhythms_seq, pitchs_seq, [], [])


    # def predict_img2token(self, rgbimgs):
    #     if not isinstance(rgbimgs, list):
    #         rgbimgs = [rgbimgs]
    #     imgs = [self.preprocessing(item) for item in rgbimgs]
    #     imgs = torch.cat(imgs).float().unsqueeze(1)
    #     output = self.model.generate(
    #         imgs.to(self.device), temperature=self.args.get("temperature", 0.2)
    #     )
    #     rhythm, pitch, lift = output
    #     return rhythm, pitch, lift

    # def predict_token(self, imgpath):
    #     imgs = []
    #     if os.path.isdir(imgpath):
    #         for item in os.listdir(imgpath):
    #             imgs.append(self.readimg(os.path.join(imgpath, item)))
    #     else:
    #         imgs.append(self.readimg(imgpath))
    #     imgs = torch.cat(imgs).float().unsqueeze(1)
    #     output = self.model.generate(
    #         imgs.to(self.device), temperature=self.args.get("temperature", 0.2)
    #     )
    #     rhythm, pitch, lift = output
    #     return rhythm, pitch, lift

    # def predict(self, imgpath):
    #     rhythm, pitch, lift = self.predict_token(imgpath)

    #     predlift = self.detokenize(lift, self.lifttokenizer)
    #     predpitch = self.detokenize(pitch, self.pitchtokenizer)
    #     predrhythm = self.detokenize(rhythm, self.rhythmtokenizer)
    #     return predrhythm, predpitch, predlift



In [27]:
import os
from omegaconf import OmegaConf


def getconfig(configpath):
    args = OmegaConf.load(configpath)

    workspace = os.path.dirname(configpath)
    for key in args.filepaths.keys():
        args.filepaths[key] = os.path.join(workspace, args.filepaths[key])
    return args


In [28]:
import glob
import os

import argparse
from random import randrange

import cv2
import numpy as np
import pandas as pd
import torch


if __name__ == "__main__":
    # --
    # parser = argparse.ArgumentParser(description="Inference single staff image")
    # parser.add_argument("filepath", type=str, help="path to staff image")
    # parsed_args = parser.parse_args()

    # os.path.dirname(__file__): 현재 실행 중인 스크립트의 파일 경로
    cofigpath = "../src/workspace/config.yaml" 
    args = getconfig(cofigpath)

    handler = StaffToScore(args)

    x_dataset_path=f"{DATA_FEATURE_PATH}/transformer/Rock-ver/stave/"
    x_all_dataset_path = glob.glob(f"{x_dataset_path}/*")
    x_file_list = [file for file in x_all_dataset_path if file.endswith(f".png")]
    x_file_list.sort()

    # imgpath = f"{DATA_TEST_PATH}/test-01.png"

    y_dataset_path=f"{DATA_FEATURE_PATH}/transformer/Rock-ver/annotation/"
    y_all_dataset_path = glob.glob(f"{y_dataset_path}/*")
    y_file_list = [file for file in y_all_dataset_path if file.endswith(f".txt")]
    y_file_list.sort()
    

    def convert_img(imgpath):
        imgs = []
        if os.path.isdir(imgpath):
            for item in os.listdir(imgpath):
                imgs.append(handler.readimg(os.path.join(imgpath, item)))
        else:
            imgs.append(handler.readimg(imgpath))
        imgs = torch.cat(imgs).float()

        return imgs

    """
    1 -- resize 전
    (298, 2404, 4)
    2 -- resize 후
    torch.Size([1, 128, 1024])
    torch.float32
    rgbimgs : torch.Size([1, 1, 128, 1024])
    """

    x_list=[]
    for idx, x_file in enumerate(x_file_list):
        convert_x_file = convert_img(x_file)
        print(f"{idx+1}:{x_file}\n -- rgbimgs : {convert_x_file.shape}")
        x_list.append(convert_x_file)

    concatenated_x_list = torch.cat(x_list, dim=0)

    # 얘가 encoder input임!!!!!!!!!! concatenated_x_list.unsqueeze(1)
    inputs = concatenated_x_list.unsqueeze(1)

    # imgs = handler.preprocessing(concatenated_x_list.unsqueeze(1))
    # print(f"------------------preprocessing---------------------")
    # print(f"-- preprocessing x : {imgs.shape}")
    # print(f"----------------------------------------------------")


    def read_txt_file(file_path):
        """
        텍스트 파일을 읽어서 내용을 리스트로 반환하는 함수
        """
        with open(file_path, 'r', encoding='utf-8') as file:
            content = file.readlines()
            # 각 줄의 개행 문자 제거
            content = [line.strip() for line in content]
        return content[0]
    # 각 파일의 내용을 담을 리스트
    contents = []

    # 각 파일을 읽어서 내용을 리스트에 추가
    for txt_path in y_file_list:
        # print("--- txt_path:", txt_path)
        content = read_txt_file(txt_path)
        # print(content)
        contents.append(content)

    print(f"------------------y data----------------------------")
    print(f"-- labeling y : {len(contents)}")
    print(f"----------------------------------------------------")

    # 각 token에 맞는 string list로 만들기

    def map_pitch(note):
        pitch_mapping = {
            "note-D4": 1,
            "note-E4": 2,
            "note-F4": 3,
            "note-G4": 4,
            "note-A4": 5,
            "note-B4": 6,
            "note-C5": 7,
            "note-D5": 8,
            "note-E5": 9,
            "note-F5": 10,
            "note-G5": 11,
            "note-A5": 12,
            "note-B5": 13
        }
        return "nonote" if note not in pitch_mapping else note
    
    def map_duration(note):
        duration_mapping =  {
            "[PAD]": 0,
            "[BOS]": 1,
            "[EOS]": 2,
            "+": 3,
            "|": 4,
            "barline": 5,
            "clef-percussion":6,

            "note-eighth": 7,
            "note-eighth.": 8,
            "note-half": 9,
            "note-half.": 10,
            "note-quarter": 11,
            "note-quarter.": 12,
            "note-sixteenth": 13,
            "note-sixteenth.": 14,
            "note-thirty_second": 15,
            "note-thirty_second.": 16,
            "note-whole": 17,
            "note-whole.": 18,

            "rest-eighth": 19,
            "rest-eighth.": 20,
            "rest-half": 21,
            "rest-half.": 22,
            "rest-quarter": 23,
            "rest-quarter.": 24,
            "rest-sixteenth": 25,
            "rest-sixteenth.": 26,
            "rest-thirty_second": 27,
            "rest-thirty_second.": 28,
            "rest-whole": 29,
            "rest-whole.": 30,

            "timeSignature-4/4": 31,
        }
        return note if note in duration_mapping else ""

    def map_notes2pitch(note_list):
        result=[]
        for notes in note_list:
            # print(notes)
            group_notes = []
            # 우선 +로 나누고, 안에 | 있는 지 확인해서 먼저 붙이기
            note_split = notes.split("+")
            # print(note_split)
            i_idx=0
            while i_idx<len(note_split):
                note_s = note_split[i_idx]
                if "|" in note_s:
                    mapped_note_chord = []
                    note_split_chord = note_s.split("|")
                    
                    # | 로 나눈 만큼 건너뛰기
                    for note_s_c in note_split_chord:
                        note_pitch, _ = note_s_c.split("_")
                        mapped_note_chord.append(f"{map_pitch(note_pitch)}")
                    group_notes.append("+".join(mapped_note_chord))
                    i_idx+=1
                elif "note" in note_s:
                    if "_" in note_s:
                        note_pitch, _ = note_s.split("_")
                        group_notes.append(f"{map_pitch(note_pitch)}")
                        i_idx+=1
                else:
                    group_notes.append(f"{map_pitch(note_s)}")
                    i_idx+=1
            # print(group_notes)
            result.append("+".join(group_notes))
        return result
    
    def map_notes2rhythm(note_list):
        result=[]
        for notes in note_list:
            group_notes = ["[BOS]"]
            # 우선 +로 나누고, 안에 | 있는 지 확인해서 먼저 붙이기
            note_split = notes.split("+")
            i_idx=0
            while i_idx<len(note_split):
                note_s = note_split[i_idx]
                if "|" in note_s:
                    mapped_note_chord = []
                    note_split_chord = note_s.split("|")
                    
                    # | 로 나눈 만큼 건너뛰기
                    for note_s_c in note_split_chord:
                        # note-thirty_second 가 있기 때문에, 이걸 어떻게 잇지? _ 로 split해버리면 thirty second로 분리되자나
                        # 분리된 거에서 첫 번째 pitch 정보는 버리고, _로 join 해놓고 
                        note_s_c_split = note_s_c.split("_")
                        note_s_c_split = note_s_c_split[1:]
                        note_duration = "_".join(note_s_c_split)

                        mapped_note_chord.append(map_duration(f"note-{note_duration}"))
                    group_notes.append("|".join(mapped_note_chord))
                    i_idx+=1
                elif "note" in note_s:
                    if "_" in note_s:
                        note_s_c_split = note_s_c.split("_")
                        note_s_c_split = note_s_c_split[1:]
                        note_duration = "_".join(note_s_c_split)

                        group_notes.append(map_duration(f"note-{note_duration}"))
                        i_idx+=1
                elif "rest" in note_s:
                    if "_" in note_s:
                        note_s_c_split = note_s_c.split("_")
                        note_s_c_split = note_s_c_split[1:]
                        note_duration = "_".join(note_s_c_split)

                        group_notes.append(map_duration(f"rest-{note_duration}"))
                        i_idx+=1
                else:
                    group_notes.append(map_duration(note_s))
                    i_idx+=1
            # print(group_notes)
            group_notes.append("[EOS]")
            result.append("+".join(group_notes))
            print("+".join(group_notes))

        return result

    rhythm_contents = map_notes2rhythm(contents)
    pitch_contents = map_notes2pitch(contents)
    token_rhythm, token_pitch = handler.all_entokenize(rhythm_contents, pitch_contents)
    print(f"------------------processing y data----------------------------")
    print(f"-- token_rhythm : {len(token_rhythm)}")
    print(f"-- token_rhythm : {token_rhythm}")
    print(f"-- token_pitch : {len(token_pitch)}")
    print(f"-- token_pitch : {token_pitch}")
    print(f"----------------------------------------------------")


    model = handler.train_model(inputs, token_rhythm, token_pitch)


1 -- resize 전
(120, 1000, 3)
2 -- resize 후
torch.Size([1, 128, 1056])
torch.float32
1:../data/processed-feature/transformer/Rock-ver/stave/Rock-ver_pad-stave-origin_10_2024-05-04_16-30-26.png
 -- rgbimgs : torch.Size([1, 128, 1056])
1 -- resize 전
(120, 1000, 3)
2 -- resize 후
torch.Size([1, 128, 1056])
torch.float32
2:../data/processed-feature/transformer/Rock-ver/stave/Rock-ver_pad-stave-origin_1_2024-05-04_16-30-26.png
 -- rgbimgs : torch.Size([1, 128, 1056])
1 -- resize 전
(120, 1000, 3)
2 -- resize 후
torch.Size([1, 128, 1056])
torch.float32
3:../data/processed-feature/transformer/Rock-ver/stave/Rock-ver_pad-stave-origin_2_2024-05-04_16-30-26.png
 -- rgbimgs : torch.Size([1, 128, 1056])
1 -- resize 전
(120, 1000, 3)
2 -- resize 후
torch.Size([1, 128, 1056])
torch.float32
4:../data/processed-feature/transformer/Rock-ver/stave/Rock-ver_pad-stave-origin_3_2024-05-04_16-30-26.png
 -- rgbimgs : torch.Size([1, 128, 1056])
1 -- resize 전
(120, 1000, 3)
2 -- resize 후
torch.Size([1, 128, 1056])
t

TypeError: Module.train() takes from 1 to 2 positional arguments but 6 were given