In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hugging Face Hub id of the JAIS chat checkpoint used throughout this notebook.
model_path = "inceptionai/jais-family-590m-chat"

# Tokenizer and causal-LM weights are both loaded from the same checkpoint id.
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",  # let the loader choose where the weights are placed
    trust_remote_code=True,  # permits the checkpoint's own modelling code to run
)



In [4]:
# Arabic instruction-style prompt template. `{Question}` is its only
# placeholder, and the text ends with the "### Response :" marker.
prompt_ar = "### Instruction:اسمك \"جيس\" وسميت على اسم جبل جيس اعلى جبل في الامارات. تم بنائك بواسطة Inception في الإمارات. أنت مساعد مفيد ومحترم وصادق. أجب دائمًا بأكبر قدر ممكن من المساعدة، مع الحفاظ على البقاء أمناً. أكمل المحادثة بين [|Human|] و[|AI|] :\n### Input:[|Human|] {Question}\n[|AI|]\n### Response :"

# Device string for tensors. "mps" is still preferred when present (the
# original hard-coded value), but fall back to CUDA or CPU so the notebook
# also runs on hosts without Apple-Silicon MPS.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def get_response(text, tokenizer=tokenizer, model=model):
    """Generate a sampled completion for an already-formatted prompt.

    Args:
        text: Full prompt string. Everything up to and including the last
            "### Response :" marker is stripped from the returned text.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
            Defaults to the notebook-level ``tokenizer`` (bound when this
            cell is executed).
        model: Model whose ``generate`` method produces the completion.
            Defaults to the notebook-level ``model`` (bound when this cell
            is executed).

    Returns:
        The decoded text after the last "### Response :" marker, or the whole
        decoded string if the marker does not occur.
    """
    input_ids = tokenizer(text, return_tensors="pt").input_ids
    # Move the prompt to the device the model itself reports instead of a
    # hard-coded one: the model may have been placed elsewhere (e.g. by
    # device_map="auto") or a different model may have been passed in.
    inputs = input_ids.to(model.device)
    input_len = inputs.shape[-1]
    generate_ids = model.generate(
        inputs,
        top_p=0.9,
        temperature=0.3,
        max_length=2048,
        # Lower bound is expressed relative to the prompt length, i.e. the
        # prompt plus 4 more tokens.
        min_length=input_len + 4,
        repetition_penalty=1.2,
        do_sample=True,
    )
    # Only the first sequence of the batch is decoded and returned.
    response = tokenizer.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )[0]
    # The decoded text is split on the response marker so that only the part
    # after its last occurrence is kept.
    response = response.split("### Response :")[-1]
    return response


# Sample Arabic question ("What is the capital of the Emirates?") used to
# exercise the prompt template and the generation helper end to end.
ques = "ما هي عاصمة الامارات؟"
text = prompt_ar.format(Question=ques)
print(get_response(text))


عاصمة دولة الإمارات العربية المتحدة، أبو ظبي، تعتبر واحدة من أكثر المدن جاذبية للسياح والمغتربين في العالم. تقع المدينة على الساحل الشرقي لشبه الجزيرة العربية وهي مركز رئيسي للأعمال التجارية والمالية والتعليم والترفيه. 

تم بناء العاصمة في عام 1966، وقد شهدت العديد من التغييرات والتحولات منذ تأسيسها. اليوم، تعد أبو ظبي واحدة من أكبر مدن الشرق الأوسط وأكثرها تطوراً، وتضم مجموعة متنوعة من المعالم السياحية والثقافية والتعليمية.


In [5]:
# Display the module tree of the loaded model.
model

JAISLMHeadModel(
  (transformer): JAISModel(
    (wte): Embedding(84992, 1536)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-17): 18 x JAISBlock(
        (ln_1): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
        (attn): JAISAttention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
        )
        (ln_2): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
        (mlp): JAISMLP(
          (c_fc): Conv1D()
          (c_fc2): Conv1D()
          (c_proj): Conv1D()
          (act): SwiGLUActivation()
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
    (relative_pe): AlibiPositionEmbeddingLayer()
  )
  (lm_head): Linear(in_features=1536, out_features=84992, bias=False)
)

In [3]:
from add_rbf_to_model import replace_ffn_with_rbf_jais
from main import print_number_of_trainable_model_parameters
# Parameter summary before the feed-forward layers are swapped.
print(f"before: {print_number_of_trainable_model_parameters(model)}")

# The return value is ignored, so the helper appears to modify `model` in place.
# NOTE(review): the meaning of the second argument (2) is defined in
# add_rbf_to_model and is not visible here — confirm before changing it.
replace_ffn_with_rbf_jais(model, 2)
print("\n")
# Same summary after the swap, for comparison with the numbers above.
print(f"after: {print_number_of_trainable_model_parameters(model)}")

before: trainable model parameters: 640555008
all model parameters: 640555020
percentage of trainable model parameters: 100.00%
Replacing feedforward layers with RBF in block 0
Replacing feedforward layers with RBF in block 1
Replacing feedforward layers with RBF in block 2
Replacing feedforward layers with RBF in block 3
Replacing feedforward layers with RBF in block 4
Replacing feedforward layers with RBF in block 5
Replacing feedforward layers with RBF in block 6
Replacing feedforward layers with RBF in block 7
Replacing feedforward layers with RBF in block 8
Replacing feedforward layers with RBF in block 9
Replacing feedforward layers with RBF in block 10
Replacing feedforward layers with RBF in block 11
Replacing feedforward layers with RBF in block 12
Replacing feedforward layers with RBF in block 13
Replacing feedforward layers with RBF in block 14
Replacing feedforward layers with RBF in block 15
Replacing feedforward layers with RBF in block 16
Replacing feedforward layers wit

In [4]:
# Display the module tree to inspect the model after replace_ffn_with_rbf_jais has run.
model

JAISLMHeadModel(
  (transformer): JAISModel(
    (wte): Embedding(84992, 1536)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-17): 18 x JAISBlock(
        (ln_1): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
        (attn): JAISAttention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
        )
        (ln_2): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
        (mlp): JAISMLP(
          (c_fc): CustomRBFFeedForward(
            (rbf_layer): RBFLayer()
          )
          (c_fc2): CustomRBFFeedForward(
            (rbf_layer): RBFLayer()
          )
          (c_proj): CustomRBFFeedForward(
            (rbf_layer): RBFLayer()
          )
          (act): SwiGLUActivation()
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
    (