In [1]:
# Reload edited project modules (e.g. src._shared) automatically before each cell runs.
# %reload_ext works whether or not the extension is already loaded in this kernel.
%reload_ext autoreload
%autoreload 2

In [2]:
import os
import random
import warnings

import accelerate
import torch
import transformers

from src._shared import (
    apply_lora_to_model,
    apply_peft_to_model,
    freeze_base_models,
    load_clip_model,
    load_config,
    load_tokenizers,
    prepare_dataset,
    save_model_and_logs,
    setup_environment,
    setup_trainer,
    train_model,
)

In [3]:
# Pin this process to the first GPU in PCI-bus order. CUDA reads these variables
# only when it initialises, so changing them after CUDA is already up in this
# kernel (e.g. re-running this cell after the model was loaded) has no effect.
if torch.cuda.is_initialized():
    warnings.warn(
        "CUDA is already initialised in this kernel; the CUDA_DEVICE_ORDER / "
        "CUDA_VISIBLE_DEVICES settings below will only apply after a kernel restart."
    )
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

train_config = load_config()
# NOTE(review): `report_to` and `run` are not used by any cell shown here.
model_name_identifier, device, report_to, run, USE_WANDB, SEED = setup_environment(train_config)

# Seed the RNGs used by the training stack.
# NOTE(review): accelerate's and transformers' set_seed each (per their docs)
# reseed python `random`, numpy and torch. The later calls therefore override
# the earlier ones: effectively numpy=SEED+2, torch=SEED+3, random=SEED+4.
# Kept unchanged so earlier runs stay reproducible; confirm the offsets are intentional.
accelerate.utils.set_seed(SEED + 1)
transformers.set_seed(SEED + 2)
torch.manual_seed(SEED + 3)
random.seed(SEED + 4)

[34m[1mwandb[0m: Currently logged in as: [33mfinnlueth[0m ([33mfinnlueth-organization[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Using device: cuda
Model identifier: protT5-CLIP-2025-01-12-19-58-05


In [4]:
# Load both tokenizers, then build the tokenised train/test splits.
# `tokenizer_plm` presumably handles protein sequences and `tokenizer_llm` the GO
# text (cf. the input_ids_sequence / input_ids_text columns shown below).
tokenizer_plm, tokenizer_llm = load_tokenizers(train_config)
dataset = prepare_dataset(
    train_config,
    tokenizer_plm,
    tokenizer_llm,
)

Loading dataset from disk...


In [5]:
# Show split sizes / columns, then one training record. List-valued columns
# (presumably the input_ids_* / attention_mask_* token lists) are summarised by
# length instead of dumped in full, to keep the output readable.
# NOTE(review): assumes no torch/numpy format is set on the dataset, so those
# columns come back as Python lists.
display(dataset)
first_train_example = dataset["train"][0]
{
    column: f"<{len(value)} values>" if isinstance(value, (list, tuple)) else value
    for column, value in first_train_example.items()
}

DatasetDict({
    train: Dataset({
        features: ['identifier', 'term', 'aspect', 'GO Name', 'GO Sentence', 'sequence', 'species', '__index_level_0__', 'sequence_processed', 'input_ids_sequence', 'attention_mask_sequence', 'input_ids_text', 'attention_mask_text'],
        num_rows: 44098
    })
    test: Dataset({
        features: ['identifier', 'term', 'aspect', 'GO Name', 'GO Sentence', 'sequence', 'species', '__index_level_0__', 'sequence_processed', 'input_ids_sequence', 'attention_mask_sequence', 'input_ids_text', 'attention_mask_text'],
        num_rows: 221346
    })
})
{'identifier': 'A0A023GUT0', 'term': 'GO:0042531', 'aspect': 'BPO', 'GO Name': 'positive regulation of tyrosine phosphorylation of STAT protein', 'GO Sentence': 'The biological process is positive regulation of tyrosine phosphorylation of STAT protein.', 'sequence': 'MRCPGVSLWGLLCLGAAAGGGRPVRLEGLRADARALTRTLSTRLQQLQLFPLTLRLSGLEGVPEGVPEGVPEGGVPPGLGWAAQRLQLFQRLLGALPGPDPRLAQVANDLENLRSLLALLGTLLGCPPPRDPRPPPPAPLAEA

In [6]:
# Build the CLIP-style model from the config and place it on `device`.
model = load_clip_model(train_config, device)

`flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.
Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded model...
All model parameters are on CUDA


In [7]:
# Quick check of which adapter branch the next cell will take (True -> LoRA).
train_config['lora']['enabled']

True

In [8]:
# Attach trainable adapters: LoRA when the config enables it, otherwise
# `apply_peft_to_model`.
# NOTE(review): `model` is reassigned from its own previous value, so re-running
# this cell likely wraps an already-adapted model again. Reload the model first.
apply_adapters = apply_lora_to_model if train_config["lora"]["enabled"] else apply_peft_to_model
model = apply_adapters(model, train_config)

target_modules: ['q', 'v']
modules_to_save: ['protein_projection', 'text_projection', 'logit_scale']
trainable params: 6,160,385 || all params: 5,039,576,066 || trainable%: 0.1222


In [9]:
# Assemble the trainer from the adapted model, tokenised dataset and run metadata.
# Both tokenizers are passed through as well.
trainer = setup_trainer(
    model,
    dataset,
    train_config,
    model_name_identifier,
    USE_WANDB,
    tokenizer_plm,
    tokenizer_llm,
)

In [10]:
# Run training with the configured trainer.
# NOTE(review): the logged metrics include an "All Similarities" column holding
# every per-example cosine similarity, which floods this cell's output. Consider
# dropping it from the metrics returned by the evaluation code (outside this notebook).
# NOTE(review): in the run shown, validation loss ends above its step-0 value
# (2.81) while training loss falls. Worth checking before trusting the saved model.
train_model(trainer)

You are not running the flash-attention implementation, expect numerical differences.


Step,Training Loss,Validation Loss,Mean Cosine Similarity,Std Cosine Similarity,Min Cosine Similarity,Max Cosine Similarity,All Similarities
0,No log,2.812042,-0.006935,0.02865,-0.08077,0.080968,"[-0.0035166405141353607, -0.04130862280726433, -0.000780944712460041, -0.031983643770217896, 0.013420842587947845, -0.003866031765937805, -0.01620946079492569, -0.005806703120470047, -0.0023584123700857162, -0.03496261686086655, 0.018445875495672226, -0.03353946655988693, 0.008779661729931831, -0.026682641357183456, -0.042406003922224045, -0.024650171399116516, -0.009432639926671982, 0.03835102915763855, -0.006106609012931585, -0.030385587364435196, 0.046013567596673965, 0.002277165651321411, 0.01096368208527565, 0.03229998052120209, 0.038726016879081726, -0.0022425632923841476, 0.0015308409929275513, -0.05811375752091408, -0.012393398210406303, -0.026377158239483833, 0.002449191175401211, 0.009680441580712795, 0.038030024617910385, -0.016191037371754646, 0.027334842830896378, -0.016037072986364365, -0.00454756710678339, -0.04120336472988129, -0.03241167962551117, -0.015479051508009434, -0.019272156059741974, 0.08096839487552643, -0.02162843942642212, -0.03645313158631325, 0.02642294391989708, -0.05106343328952789, 0.027701983228325844, 0.025350790470838547, 0.0310086477547884, -0.017373817041516304, 0.0027919113636016846, -0.015194120816886425, -0.03411514312028885, -0.02134641259908676, -0.021219804883003235, -0.037569593638181686, -0.03607146069407463, -0.0006603021174669266, -0.044190146028995514, -0.01486707292497158, -0.004071785137057304, -0.015853147953748703, 0.00569523312151432, -0.013443637639284134, 0.01075040828436613, -0.013473739847540855, -0.02137000858783722, -0.010724730789661407, 0.0034652319736778736, -0.005780436098575592, -0.025120384991168976, -0.019910283386707306, -0.0002989843487739563, -0.07220489531755447, 0.028723493218421936, -0.0729510635137558, 0.03555825352668762, -0.007537211291491985, 0.016677213832736015, -0.025576744228601456, 0.008100741542875767, -0.014573995023965836, 0.004173662513494492, -0.05386916548013687, -0.028938405215740204, -0.05407356470823288, 
0.006007778458297253, -0.0052340589463710785, 0.00883195735514164, -0.026421431452035904, 0.07054607570171356, 0.01115869265049696, 0.02127894200384617, 0.020744217559695244, 0.015967974439263344, -0.007922408170998096, -0.02779102697968483, 0.009017374366521835, -0.007212817203253508, 0.009285982698202133, -0.004219105467200279, -0.02728991024196148, -0.006484813988208771, -0.02544514276087284, 0.022664153948426247, -0.007569001987576485, -0.02710754983127117, 0.0639718770980835, 0.06548230350017548, -0.034692782908678055, -0.008661539293825626, -0.04006943106651306, -0.039812542498111725, 0.024205835536122322, 0.055410195142030716, -0.016718169674277306, 0.011515711434185505, -0.02564706839621067, -0.027331463992595673, -0.03430574759840965, -0.040234293788671494, -0.0422171913087368, 0.016054252162575722, -0.03914959728717804, -0.025379203259944916, -0.004393904469907284, 0.00733895692974329, -0.040693383663892746, -0.03360693156719208, -0.044456060975790024, 0.011223231442272663, 0.016995515674352646, -0.04506484046578407, -0.049331024289131165, 0.020411478355526924, -0.037328317761421204, 0.008271075785160065, -0.03561434522271156, -0.0028811078518629074, -0.01927327737212181, 0.0175146646797657, 0.0067808437161147594, -0.01810886710882187, -0.0011753514409065247, -0.02557564713060856, -0.014831334352493286, -0.0012004300951957703, 0.01915668696165085, -0.028038660064339638, -0.0027948785573244095, -0.03897298127412796, -0.031534530222415924, 0.00773528590798378, -0.03011964075267315, 0.039729855954647064, -0.011810820549726486, 0.00652321195229888, -0.018795102834701538, -0.044616639614105225, 0.009287858381867409, 0.008082117885351181, -0.04313266649842262, -0.007967373356223106, 0.031479351222515106, 0.02713150531053543, -0.016314323991537094, -0.028636012226343155, 0.04727890342473984, 0.011601608246564865, 0.05897875130176544, -0.04226271063089371, -0.028052587062120438, 0.03392761945724487, 0.019658394157886505, 0.026902347803115845, 
-0.005765422247350216, 0.00955285131931305, 0.01273639127612114, -0.029992181807756424, 0.04223836585879326, -0.054655224084854126, -0.022576868534088135, -0.04775749146938324, -0.0060442546382546425, 0.028644509613513947, 0.009279467165470123, -0.007320278324186802, 0.002679976634681225, -0.003701872192323208, -0.050900284200906754, -0.044378068298101425, -0.028894536197185516, 0.038389354944229126, -0.0556928887963295, 0.02019364759325981, -0.04019720107316971, 0.02494431659579277, -0.013175195083022118, -0.027662256732583046, 0.04139895364642143, -0.035232558846473694, -0.017066635191440582, -0.010675124824047089, 0.029514696449041367, -0.07227686047554016, -0.009967450052499771, 0.013759428635239601, 0.0010207961313426495, -0.03558926284313202, -0.015301793813705444, 0.0019140318036079407, -0.0025446098297834396, -0.021819209679961205, 0.0055390577763319016, 0.0007956260815262794, 0.02494148351252079, -0.0699399784207344, -0.03965046629309654, -0.0012775575742125511, 0.0016405768692493439, -0.08077022433280945, -0.019708693027496338, 0.008533216081559658, 0.04041785001754761, 0.011400694958865643, 0.01592973619699478, -0.0002118367701768875, -0.023882340639829636, -0.006231994833797216, -0.01312408410012722, 0.03403007239103317, -0.0017402563244104385, 0.027850985527038574, -0.05274180322885513, 0.03964783623814583, 0.013518495485186577, 0.01746673323214054, -0.026844114065170288, 0.025352517142891884, 0.01693667098879814, -0.00613074004650116, -0.03344491496682167, -0.018903624266386032, 0.052244361490011215, -0.045233726501464844, -0.013961861841380596, 0.018258865922689438, 0.01168374065309763, 0.016568783670663834, -0.021492550149559975, -0.004532109946012497, -0.0755028948187828, 0.006027234718203545, -0.011655082926154137, 0.021285448223352432, -0.036070432513952255]"
32,1.001800,2.93783,-0.040089,0.059728,-0.19834,0.157292,"[0.006150547415018082, 0.006417357362806797, -0.03417697921395302, -0.14724381268024445, -0.06903433799743652, -0.06364580988883972, 0.037042565643787384, -0.054714761674404144, -0.056644562631845474, -0.10242101550102234, -0.08916041254997253, 0.042549487203359604, -0.13334506750106812, 0.0005814218893647194, -0.00516901072114706, -0.09341704100370407, 0.011992224492132664, -0.07431180030107498, -0.07076869159936905, -0.016994871199131012, -0.023069176822900772, 0.04321540519595146, -0.06417019665241241, -0.07363730669021606, -0.070689357817173, -0.0487787202000618, -0.0714716985821724, 0.006761023309081793, -0.0967433899641037, 0.03815479949116707, -0.009878375567495823, -0.10482454299926758, -0.08245857805013657, 0.04117643088102341, 0.04853445291519165, 0.004905343987047672, -0.06722507625818253, 0.009300803765654564, -0.19057075679302216, -0.037939418107271194, -0.00587083026766777, -0.08669696748256683, -0.06081641465425491, -0.10071061551570892, -0.036726318299770355, -0.08205758035182953, -0.03135204315185547, -0.07635936141014099, -0.016306709498167038, -0.05397000163793564, -0.04807721823453903, -0.07905798405408859, -0.053780000656843185, -0.0037616787012666464, 0.03178956359624863, -0.06844628602266312, -0.05531316250562668, -0.11216171085834503, -0.09241408854722977, -0.1320371925830841, -0.08205008506774902, -0.007653312757611275, -0.06292837113142014, 0.013886960223317146, -0.0377705842256546, -0.09532483667135239, -0.043891262263059616, -0.13317620754241943, -0.13043776154518127, -0.08024151623249054, -0.09067307412624359, 0.05273683741688728, 0.010902846232056618, -0.12396138906478882, 0.00016979104839265347, -0.08562162518501282, -0.037386488169431686, -0.0898321345448494, 0.002711697481572628, -0.05859014764428139, -0.06255213171243668, 0.05289943516254425, 0.08715076744556427, -0.1031029149889946, 0.033560268580913544, -0.15932485461235046, -0.0525931641459465, -0.07769189774990082, 
-0.03358675539493561, 0.01387110911309719, -0.012370769865810871, -0.03428954631090164, -0.13710449635982513, -0.023266376927495003, 0.0008131349459290504, -0.07392800599336624, 0.0018818373791873455, 0.0005362518131732941, -0.010538540780544281, -0.021193362772464752, 0.04658246040344238, -0.08578470349311829, -0.02411825582385063, -0.19833973050117493, -0.024906249716877937, -0.08060795068740845, 0.03661838173866272, -0.00928499735891819, 0.0557713583111763, 0.011626551859080791, -0.03407883644104004, -0.023715944960713387, -0.1648748815059662, 0.09237085282802582, -0.02514740079641342, -0.11301815509796143, -0.039006952196359634, 0.008353615179657936, 0.007710696198046207, -0.037805818021297455, -0.07004030048847198, -0.07478944957256317, -0.046683937311172485, -0.11693733185529709, -0.1187884658575058, 0.012520326301455498, 0.07347496598958969, -0.16741237044334412, -0.07545220851898193, -0.05995210260152817, -0.10925216972827911, -0.09416672587394714, -0.0019298712722957134, -0.0004917634651064873, -0.07526767253875732, -0.06608697772026062, 0.03562845662236214, 0.024838700890541077, -0.06419000029563904, 0.018672766163945198, -0.0422268770635128, -0.08947101980447769, 0.07007606327533722, -0.06660880893468857, -0.033875465393066406, 0.018478576093912125, -0.0974428579211235, -0.15826843678951263, 0.0874880999326706, -0.029772739857435226, -0.07040271162986755, -0.03844957426190376, 0.0004692976363003254, -0.14031025767326355, -0.006234588101506233, -0.04531615972518921, -0.004580494947731495, -0.11604609340429306, 0.05197928473353386, 0.08696531504392624, 0.035531267523765564, -0.002606830559670925, -0.0013103079982101917, -0.032940544188022614, -0.05186108499765396, -0.12400424480438232, -0.08600728213787079, -0.08134816586971283, -0.01100989617407322, -0.08576320856809616, 0.006526175886392593, 0.08842878043651581, -0.06010442227125168, -0.02412571758031845, -0.12549372017383575, -0.005503895226866007, 0.06801740825176239, 0.019829416647553444, 
-0.11785230040550232, -0.038742486387491226, -0.044765643775463104, -0.030006948858499527, -0.01278829574584961, 0.0188322514295578, -0.057800471782684326, 0.036929838359355927, -0.11417032033205032, -0.05466126650571823, 0.010275611653923988, -0.13371120393276215, -0.03325055539608002, -0.06894341111183167, 0.052708834409713745, -0.0778891071677208, -0.08829745650291443, -0.0851738452911377, 0.021239180117845535, -0.10727162659168243, -0.019254585728049278, -0.0596122071146965, 0.009890682995319366, -0.18267309665679932, -0.07803304493427277, 0.039843566715717316, 0.046585749834775925, -0.06067685782909393, -0.06875933706760406, -0.00598858343437314, -0.04831036925315857, 0.15729229152202606, -0.038073573261499405, -0.05468900501728058, 0.03876751288771629, -0.12497285008430481, 0.005440860986709595, -0.016275646165013313, -0.06822827458381653, 0.018811922520399094, 0.02459607645869255, -0.0028228159062564373, -0.03288208320736885, 0.052415452897548676, -0.06174612045288086, -0.0880247950553894, -0.05471460148692131, -0.05492783337831497, -0.015221748501062393, -0.0584116131067276, -0.12344786524772644, -0.07729402184486389, 0.03313835710287094, -0.01425158604979515, -0.01778922975063324, 0.019743531942367554, -0.026618551462888718, -0.0919700413942337, -0.0011551929637789726, 0.02015722170472145, -0.04165166988968849, -0.06731513142585754, -0.0589098185300827, -0.14370658993721008, -0.11425474286079407, -0.08521019667387009, -0.057953130453825, -0.028773633763194084, -0.09382151067256927, -0.062220364809036255, -0.11543573439121246, -0.027865249663591385, -0.17283150553703308, 0.03390493243932724, -0.12027762830257416, 0.08953464776277542, -0.030754560604691505, 0.025273991748690605]"
64,0.762000,2.897139,0.016863,0.053419,-0.125309,0.148786,"[0.01012524962425232, 0.022515960037708282, 0.06029621884226799, 0.051348261535167694, 0.006725095212459564, -0.028424080461263657, -0.04028325527906418, -0.04841388761997223, -0.05721644312143326, 0.0917595624923706, 0.023812534287571907, 0.03774102032184601, 0.050187528133392334, 0.10295160859823227, -0.04770013689994812, -0.10679447650909424, -0.08548085391521454, 0.05838223174214363, 0.03095439076423645, 0.05528640002012253, 0.011877978220582008, 0.026195339858531952, 0.09450700134038925, -0.022025158628821373, 0.012411076575517654, 0.03960094600915909, 0.05699284374713898, 0.08628828823566437, 0.07257251441478729, -0.02270498499274254, 0.09323452413082123, 0.0056963167153298855, 0.0063959648832678795, 0.018602659925818443, 0.05081206187605858, -0.022856825962662697, -0.04084058105945587, 0.04392458498477936, 0.017285244539380074, 0.016519509255886078, 0.024899115785956383, -0.036520980298519135, -0.0006596934981644154, 0.05540941283106804, -0.037282444536685944, -0.014827403239905834, 0.04016847535967827, 0.07847776263952255, 0.0025387676432728767, 0.02792990207672119, -0.016960613429546356, 0.07802671194076538, 0.01820044219493866, 0.02624436467885971, 0.0116043072193861, -0.018793553113937378, -0.039956435561180115, 0.033022716641426086, -0.012579991482198238, 0.14878590404987335, 0.017544515430927277, 0.038959331810474396, -0.0726633220911026, 0.1107599139213562, -0.03817324340343475, 0.0036228946410119534, -0.11683797836303711, 0.02849392220377922, 0.05548448488116264, 0.03506661579012871, 0.009223267436027527, -0.031989824026823044, 0.02102568931877613, -0.07196411490440369, -0.020155472680926323, -0.01759490929543972, -0.019621573388576508, 0.0029443278908729553, 0.009970285929739475, 0.014353626407682896, 0.004515003878623247, 0.009128022938966751, -0.01804140955209732, -0.01467866264283657, 0.035355281084775925, -0.020037256181240082, 0.06634639203548431, 0.013577738776803017, 
0.05384436994791031, 0.03755202144384384, 0.013351592235267162, 0.04291865974664688, -0.06878245621919632, 0.06378531455993652, 0.11225371062755585, -0.030232654884457588, -0.065120168030262, -0.04141434654593468, -0.03537924215197563, 0.05141172558069229, 0.10392792522907257, 0.04231361672282219, 0.04900072142481804, 0.02755134552717209, 0.054697729647159576, -0.020589174702763557, 0.03254135325551033, -0.01479119248688221, 0.005815611220896244, 0.006031007505953312, -0.006361903622746468, 0.05332387611269951, 0.03139110654592514, 0.06070343405008316, 0.015188436955213547, 0.060356415808200836, 0.025636736303567886, 0.0001598089002072811, -0.047620050609111786, 0.02439703792333603, 0.12097269296646118, 0.0019504548981785774, 0.1016276627779007, 0.05228225141763687, -0.01986090838909149, -0.022265542298555374, -0.04591773822903633, 0.009855952113866806, 0.04538939893245697, 0.07347317039966583, -0.009189264848828316, -0.05358203500509262, 0.029040757566690445, -0.017029890790581703, 0.0057282838970422745, -0.08265197277069092, 0.01973632723093033, 0.09936271607875824, -0.013199524022638798, 0.017054565250873566, 0.09342972934246063, 0.026849135756492615, 0.0014355247840285301, -0.00893888995051384, 0.03248855471611023, -0.025037255138158798, 0.013436190783977509, -0.05647646635770798, 0.09312722831964493, 0.042275674641132355, 0.07132098078727722, 0.0013337093405425549, -0.020928362384438515, -0.0004508309066295624, 0.0017087114974856377, 0.0923784002661705, -0.005101293791085482, 0.0019495636224746704, 0.10668009519577026, -0.020238539204001427, -0.02274557203054428, -0.019063066691160202, -0.0358545258641243, -0.005681149661540985, 0.08816280961036682, -0.021200362592935562, 0.04626820981502533, -0.06005461513996124, 0.013168138451874256, -0.003858706448227167, 0.1068219244480133, 0.03362724557518959, 0.14836308360099792, 0.12273849546909332, 0.06790152937173843, -0.021604567766189575, 0.10968391597270966, 0.058302249759435654, 0.09665140509605408, 
0.09372186660766602, 0.05086774379014969, -0.04838508367538452, -0.02000463381409645, 0.009213129058480263, 0.12102100253105164, -0.06556472182273865, -0.029791483655571938, -0.058802347630262375, -0.002439926378428936, -0.03237324580550194, 0.0828731507062912, 0.0837964415550232, 0.059304043650627136, -0.057636357843875885, 0.08758558332920074, 0.04248328134417534, -0.029895469546318054, 0.05785255879163742, 0.09989476948976517, -0.02050994709134102, 0.08519266545772552, 0.008227700367569923, 0.022863302379846573, 0.002796025015413761, 0.017266107723116875, -0.11148475855588913, 0.035878460854291916, -0.10113607347011566, -0.007744528818875551, 0.10741528123617172, 0.03370889276266098, 0.14345428347587585, -0.11594589054584503, 0.010654515586793423, 0.03748457878828049, 0.03852025046944618, 0.051924414932727814, 0.03607122600078583, 0.05645658075809479, -0.08083019405603409, -0.01132119633257389, -0.05963052064180374, 0.017257235944271088, -0.04107477143406868, 0.008609730750322342, -0.005382734816521406, -0.004909828770905733, 0.009150231257081032, 0.09237782657146454, 0.021862603724002838, 0.035693176090717316, -0.11337918043136597, 0.06696906685829163, 0.09142166376113892, -0.0887165516614914, -0.12530925869941711, 8.629076182842255e-05, 0.08836058527231216, -0.06907286494970322, 0.05947694182395935, -0.0032023885287344456, 0.06259371340274811, 0.05690572038292885, 0.03691796958446503, -0.024449922144412994, 0.06410861015319824, -0.011767515912652016, 0.1013530045747757, 0.04724912345409393, -0.004122176207602024, 0.11064484715461731, 0.019800491631031036, 0.033978648483753204, 0.02768292836844921, -0.010788705199956894, -0.084560826420784]"
96,0.767800,2.871053,0.033102,0.048495,-0.095688,0.177712,"[0.09244705736637115, -0.03583179786801338, 0.040058769285678864, 0.024585068225860596, 0.05856143683195114, -0.0198783241212368, -0.09249424934387207, 0.029226500540971756, 0.03740139305591583, -0.011503271758556366, 0.024679210036993027, -0.06796202063560486, -0.03763935714960098, 0.07292194664478302, -0.0014934428036212921, 0.07351577281951904, 0.05543527752161026, -0.057907212525606155, 0.13177910447120667, 0.06571722030639648, 0.030319463461637497, 0.13850393891334534, 0.12106785178184509, -0.04048421233892441, -0.039995212107896805, -0.027618959546089172, 0.0355924516916275, 0.0759059488773346, 0.07689785957336426, 0.04067695513367653, 0.029864467680454254, 0.070533886551857, 0.022842250764369965, 0.019820857793092728, 0.03432805836200714, 0.09676669538021088, 0.005585865583270788, 0.011613757349550724, 0.03293279930949211, 0.028205014765262604, 0.0801394060254097, 0.09137718379497528, 0.146570086479187, -9.817676618695259e-05, 0.004044080153107643, 0.08810294419527054, 0.08313792198896408, 0.024716615676879883, 0.052125781774520874, 0.026095710694789886, -0.014238383620977402, -0.06021840125322342, 0.07264120876789093, -0.01217712927609682, 0.027964934706687927, 0.023497363552451134, 0.07128298282623291, 0.0017221048474311829, 0.07844632863998413, -0.002757417969405651, 0.08764081448316574, -0.014881718903779984, 0.06140205264091492, 0.046251241117715836, 0.10394700616598129, 0.08039190620183945, 0.03759270906448364, 0.040482692420482635, -0.042492661625146866, 0.0437704399228096, -0.01283006090670824, 0.01915016956627369, 0.08281472325325012, 0.0573909655213356, 0.10914884507656097, -0.0365397185087204, 0.044058993458747864, 0.101210817694664, 0.07553580403327942, -0.07310563325881958, -0.02434372343122959, -0.010491611436009407, 0.05679381638765335, 0.06983903050422668, 0.015026440843939781, -0.012261591851711273, 0.03151891008019447, 0.004188790451735258, -0.005067216232419014, 
0.09376261383295059, 0.02018805779516697, 0.04133366048336029, 0.06371290236711502, 0.014321033842861652, -0.0068781874142587185, 0.05315787345170975, 0.04270516335964203, 0.02679651975631714, -0.051291462033987045, -0.04090699180960655, 0.05678921192884445, 0.01982358656823635, -0.022096332162618637, 0.04200927913188934, 0.04127014800906181, 0.020111393183469772, -0.009674804285168648, 0.06092177703976631, 0.013230671174824238, -0.013430491089820862, -0.022114494815468788, -0.040573522448539734, -0.04434889182448387, -0.031970057636499405, 0.0601678267121315, 0.07890117168426514, 0.03636611998081207, 0.010668357834219933, 0.050308749079704285, 0.0435456745326519, -0.05923346430063248, 0.033920131623744965, -0.01500620599836111, 0.018757324665784836, 0.00909469649195671, 0.00518546998500824, 0.013863430358469486, -0.006530421786010265, 0.03621390461921692, 0.12309940159320831, 0.03831997141242027, -0.008132901974022388, 0.1361405849456787, 0.026372075080871582, 0.032218001782894135, 0.05312040448188782, 0.03805236890912056, 0.08082103729248047, 0.06839257478713989, -0.02582981251180172, 0.03682175278663635, 0.044820018112659454, 0.07200506329536438, 0.11185365915298462, 0.03182201460003853, 0.0295298732817173, -0.0273769311606884, -0.09568750858306885, 0.11031918972730637, 0.052637431770563126, 0.036264777183532715, 0.03372780978679657, 0.06631975620985031, 0.07531493902206421, 0.0733092799782753, -0.004443880170583725, 0.07057823240756989, -0.034975938498973846, 0.047945015132427216, 0.03413588926196098, 0.05171845108270645, 0.05043818801641464, 0.04289877414703369, 0.010559165850281715, 0.01726178079843521, 0.09900821000337601, -0.03186291083693504, 0.013098080642521381, 0.042195554822683334, 0.05268756300210953, 0.047921326011419296, 0.07752035558223724, 0.07276880741119385, -0.05856591463088989, -0.08106984943151474, -0.013516792096197605, 0.0012761624529957771, 0.11338711529970169, 0.06245821714401245, 0.15391549468040466, 0.10283569246530533, 
0.10887417197227478, -0.04450433701276779, 0.08561378717422485, 0.03480411320924759, 0.10072141885757446, 0.032274432480335236, 0.08750609308481216, -0.03786769509315491, 0.052622027695178986, 0.016999738290905952, 0.1319873183965683, -0.06231507286429405, 0.02790633775293827, 0.1002398356795311, 0.03255096822977066, 0.1777123510837555, -0.00030976906418800354, 0.07815157622098923, 0.0026903878897428513, -0.018115445971488953, -0.08226673305034637, -0.05194420367479324, 0.13560451567173004, -0.026314858347177505, 0.024608969688415527, 0.04840511083602905, 0.08614153414964676, 0.05115697532892227, 0.054811954498291016, 0.04210951179265976, 0.029668856412172318, 0.06162676215171814, 0.07678253203630447, 0.05568697303533554, 0.08635753393173218, 0.04772185534238815, -0.006383596919476986, 0.006990391761064529, 0.03585783392190933, 0.033905256539583206, 0.06139250099658966, 0.03944544121623039, 0.09134829044342041, 0.058958668261766434, -0.002909877337515354, 0.02387670800089836, 0.07979394495487213, 0.05856161564588547, 0.038012102246284485, 0.0686381608247757, 0.0030035972595214844, 0.05762101709842682, 0.10659018903970718, -0.027440330013632774, -0.02100170962512493, 0.031682565808296204, 0.045524612069129944, -0.047748636454343796, 0.0038885753601789474, 0.02101030945777893, 0.05188821256160736, 0.02150084637105465, 0.04079337790608406, 0.02694336697459221, 0.012424550950527191, 0.0590825118124485, 0.0421094112098217, 0.04632692411541939, 0.10126069188117981, 0.01479202602058649, -0.02227858453989029, 0.005671832710504532, 0.055505335330963135, -0.0033507905900478363, 0.052183594554662704]"




In [12]:
# Persist model, config and training logs under a directory named after
# `model_name_identifier` (see the printed path below).
save_model_and_logs(model, trainer, model_name_identifier, train_config)



Model, config, and log saved to: ../tmp/models/protT5-CLIP-2025-01-12-19-58-05


In [11]:
# NOTE(review): stale, commented-out alternative to the LoRA/PEFT cell above.
# It reads `train_config.lora.enabled` as attributes, whereas the rest of the
# notebook indexes `train_config` as a dict (`train_config['lora']['enabled']`).
# It is also the only user of the imported `freeze_base_models`.
# Delete it (and that import), or fix it before re-enabling.
# if train_config.lora.enabled:
#     model = apply_lora_to_model(model, train_config)
# else:
#     freeze_base_models(model)