<a href="https://colab.research.google.com/github/ShubhamZoro/Generative_AI/blob/main/summarize/PEFT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%pip install -U datasets==2.17.0

%pip install --upgrade pip
%pip install --disable-pip-version-check \
    torch==1.13.1 \
    torchdata==0.5.1 --quiet

%pip install \
    transformers==4.27.2 \
    evaluate==0.4.0 \
    rouge_score==0.1.2 \
    loralib==0.1.1 \
    peft==0.3.0 --quiet

In [6]:
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig, TrainingArguments, Trainer
import torch
import time
import evaluate
import pandas as pd
import numpy as np

In [7]:
# Hub identifier of the summarization corpus used throughout the notebook.
huggingface_dataset_name = "xsum"

# This dataset is loaded through a builder script. `datasets` prints a notice
# unless the script is explicitly trusted, and announces that the flag becomes
# mandatory in its next major release, so opt in explicitly here.
# NOTE(review): this executes the dataset's loading script from the Hub —
# confirm the source is trusted before running.
dataset = load_dataset(huggingface_dataset_name, trust_remote_code=True)

# Display the splits (train / validation / test) and their columns.
dataset

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


Downloading builder script:   0%|          | 0.00/5.76k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/6.24k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/255M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.00M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/204045 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/11332 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/11334 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 204045
    })
    validation: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 11332
    })
    test: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 11334
    })
})

In [None]:
# Checkpoint id shared by the tokenizer and the model.
model_name = 'google/flan-t5-base'

# Load the tokenizer, then the seq2seq weights with bfloat16 as the torch dtype.
tokenizer = AutoTokenizer.from_pretrained(model_name)
original_model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
)

In [None]:
def print_number_of_trainable_model_parameters(model):
    """Build a three-line report of a model's parameter counts.

    Despite the name, nothing is printed: the report is returned as a string
    so the caller decides how to display it.

    Args:
        model: Any object exposing ``named_parameters()`` that yields
            ``(name, param)`` pairs, where ``param`` has ``numel()`` and a
            ``requires_grad`` flag.

    Returns:
        A string with the trainable count, the total count, and the trainable
        share as a percentage with two decimals.

    Raises:
        ZeroDivisionError: If the model has no parameters at all.
    """
    # Collect (size, trainable?) once per parameter.
    sizes = [(p.numel(), p.requires_grad) for _, p in model.named_parameters()]
    total = sum(count for count, _ in sizes)
    trainable = sum(count for count, needs_grad in sizes if needs_grad)
    share = 100 * trainable / total
    return (
        f"trainable model parameters: {trainable}\n"
        f"all model parameters: {total}\n"
        f"percentage of trainable model parameters: {share:.2f}%"
    )

# The helper returns the report as a string, so print it explicitly.
print(print_number_of_trainable_model_parameters(original_model))

In [None]:
# Test-set row used for the zero-shot sanity check.
index = 200

# Read the single row directly rather than slicing whole columns.
dialogue = dataset['test'][index]['document']
summary = dataset['test'][index]['summary']

prompt = f"""
Summarize the following conversation.

{dialogue}

Summary:
"""

# Tokenize the prompt, generate up to 200 new tokens, and decode the first
# (only) returned sequence without special tokens.
inputs = tokenizer(prompt, return_tensors='pt')
generated_ids = original_model.generate(inputs["input_ids"], max_new_tokens=200)
output = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

# Separator of 99 dashes, reused by later cells.
dash_line = '-' * 99

print(dash_line)
print(f'INPUT PROMPT:\n{prompt}')
print(dash_line)
print(f'BASELINE HUMAN SUMMARY:\n{summary}\n')
print(dash_line)
print(f'MODEL GENERATION - ZERO SHOT:\n{output}')

In [None]:
def tokenize_function(example):
    """Turn a batch of raw examples into model inputs and training labels.

    Intended for ``Dataset.map(..., batched=True)``: ``example`` is a batch
    whose ``"document"`` and ``"summary"`` entries are lists of strings.

    Each document is wrapped in the instruction prompt and tokenized to a
    fixed length (``padding="max_length"``, ``truncation=True``). Three keys
    are added to the batch:

    * ``input_ids``      – prompt token ids.
    * ``attention_mask`` – marks real prompt tokens so padding is not attended.
    * ``labels``         – summary token ids, with every padding position
      replaced by ``-100`` so the loss skips it.

    Args:
        example: Batch mapping with ``"document"`` and ``"summary"`` lists.

    Returns:
        The same mapping with ``input_ids``, ``attention_mask`` and ``labels``
        added.
    """
    start_prompt = 'Summarize the following conversation.\n\n'
    end_prompt = '\n\nSummary: '
    prompt = [start_prompt + dialogue + end_prompt for dialogue in example["document"]]

    # Plain Python lists are enough here; `map` stores the values itself, so
    # there is no need to build torch tensors first.
    model_inputs = tokenizer(prompt, padding="max_length", truncation=True)
    example['input_ids'] = model_inputs["input_ids"]
    example['attention_mask'] = model_inputs["attention_mask"]

    label_ids = tokenizer(example["summary"], padding="max_length", truncation=True)["input_ids"]
    # Mask padding in the targets: without this, most of the loss would come
    # from predicting pad tokens instead of the summary text.
    pad_token_id = tokenizer.pad_token_id
    example['labels'] = [
        [token_id if token_id != pad_token_id else -100 for token_id in sequence]
        for sequence in label_ids
    ]

    return example

# The dataset contains 3 splits: train, validation, test.
# Only every 100th example of each split is kept to keep the demo fast.
# Subsample FIRST and tokenize afterwards: the surviving rows are the same as
# tokenizing everything and filtering later, but about 99% of the tokenization
# work is avoided.
tokenized_datasets = dataset.filter(lambda example, index: index % 100 == 0, with_indices=True)
# tokenize_function processes the remaining rows of every split in batches.
tokenized_datasets = tokenized_datasets.map(tokenize_function, batched=True)
# Drop the raw text columns; only the tokenized fields are needed for training.
tokenized_datasets = tokenized_datasets.remove_columns(['id', 'document', 'summary',])

In [None]:
# Report the (rows, columns) shape of each subsampled split, then the full
# DatasetDict summary.
print("Shapes of the datasets:")
print("\n".join(
    f"{label}: {tokenized_datasets[split_name].shape}"
    for label, split_name in (
        ("Training", "train"),
        ("Validation", "validation"),
        ("Test", "test"),
    )
))

print(tokenized_datasets)

In [None]:
# Directory passed to TrainingArguments as the output location for the full
# fine-tuning run.
output_dir = 'dialogue-summary-training'

# Full fine-tuning configuration. Both num_train_epochs=1 and max_steps=1 are
# set. NOTE(review): per the TrainingArguments documentation a positive
# max_steps takes precedence, so this is effectively a one-step smoke run —
# confirm that is intended.
training_args = TrainingArguments(
    output_dir=output_dir,
    max_steps=1,
    num_train_epochs=1,
    logging_steps=1,
    learning_rate=1e-5,
    weight_decay=0.01,
)

# Trainer over the subsampled train / validation splits; every weight of
# original_model is passed in for updating.
trainer = Trainer(
    args=training_args,
    model=original_model,
    eval_dataset=tokenized_datasets['validation'],
    train_dataset=tokenized_datasets['train'],
)

In [None]:
# Run the full fine-tuning loop configured on `trainer` above.
trainer.train()

In [None]:
# Same test row as the zero-shot check, now scored after the training step.
index = 200
dialogue = dataset['test'][index]['document']
human_baseline_summary = dataset['test'][index]['summary']

prompt = f"""
Summarize the following conversation.

{dialogue}

Summary:
"""

# Put the prompt ids on whatever device the model's weights are on.
model_device = next(original_model.parameters()).device
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model_device)

# Single-beam generation capped at 200 new tokens.
greedy_config = GenerationConfig(max_new_tokens=200, num_beams=1)
original_model_outputs = original_model.generate(
    input_ids=input_ids,
    generation_config=greedy_config,
)
original_model_text_output = tokenizer.decode(
    original_model_outputs[0],
    skip_special_tokens=True,
)

# `dash_line` comes from the zero-shot cell earlier in the notebook.
print(dash_line)
print(f'BASELINE HUMAN SUMMARY:\n{human_baseline_summary}')
print(dash_line)
print(f'ORIGINAL MODEL:\n{original_model_text_output}')


In [None]:
from peft import LoraConfig, get_peft_model, TaskType

# LoRA adapter settings for the seq2seq model (FLAN-T5).
lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    target_modules=["q", "v"],  # modules whose names match "q" / "v" get adapters
    r=32,                       # rank of the low-rank update
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)

In [None]:
# Wrap the base model with the adapter described by `lora_config`, then show
# how many parameters are trainable in the wrapped model.
peft_model = get_peft_model(original_model, lora_config)
parameter_report = print_number_of_trainable_model_parameters(peft_model)
print(parameter_report)

In [None]:
# Output directory for the PEFT run (rebinds the `output_dir` used by the
# full fine-tuning cell).
output_dir = 'Shubham_peft-dialogue-summary-training'

# PEFT training configuration. The learning rate is higher than the one used
# for full fine-tuning. Both num_train_epochs=1 and max_steps=1 are set.
# The Hub-upload options (push_to_hub_model_id, push_to_hub_organization,
# push_to_hub_token) are left disabled; uploading is done in a later cell.
peft_training_args = TrainingArguments(
    output_dir=output_dir,
    max_steps=1,
    num_train_epochs=1,
    logging_steps=1,
    learning_rate=1e-3,
    auto_find_batch_size=True,
)

# Trainer for the LoRA-wrapped model; no evaluation split is supplied.
peft_trainer = Trainer(
    args=peft_training_args,
    model=peft_model,
    train_dataset=tokenized_datasets["train"],
)

In [None]:
# Run the training loop for the LoRA-wrapped model configured on `peft_trainer`.
peft_trainer.train()

In [None]:
from huggingface_hub import notebook_login

# Interactive Hugging Face Hub login; run before the upload cell below, which
# passes use_auth_token=True.
notebook_login()

In [None]:
# Upload the PEFT model to the Hub repository below, authenticating via
# use_auth_token=True.
peft_model.push_to_hub(
    "ShubhamZoro/FLan-T5-Summarize",
    commit_message="basic training",
    use_auth_token=True,
)

In [19]:
%pip install -U datasets==2.17.0

%pip install --upgrade pip
%pip install --disable-pip-version-check \
    torch==1.13.1 \
    torchdata==0.5.1 --quiet

%pip install \
    transformers==4.27.2 \
    evaluate==0.4.0 \
    rouge_score==0.1.2 \
    loralib==0.1.1 \
    peft==0.3.0 --quiet


[0m

In [20]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig, TrainingArguments, Trainer
import torch

In [21]:
import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Hub repository that holds the trained adapter.
peft_model_id = "ShubhamZoro/FLan-T5-Summarize"

# The adapter config records which base checkpoint it was trained on.
config = PeftConfig.from_pretrained(peft_model_id)

# Load the base checkpoint first ...
base_model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path, return_dict=True, device_map='auto')
# ... then attach the adapter weights. Without this step `model1` would be the
# unmodified base model and the "PEFT MODEL" output below would not reflect
# the fine-tuning at all.
model1 = PeftModel.from_pretrained(base_model, peft_model_id)
model1.eval()  # inference only: disables dropout (including LoRA dropout)

tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

In [22]:
# Test-set row used to compare the reloaded model against the human summary.
# NOTE: this cell still relies on `dataset` loaded earlier in the notebook.
index = 100
# Read the single row directly rather than materializing whole columns.
dialogue = dataset['test'][index]['document']
baseline_human_summary = dataset['test'][index]['summary']

prompt = f"""
Summarize the following conversation.

{dialogue}

Summary: """

# Tokenize WITHOUT padding: a single prompt needs none, and padding to
# max_length while forwarding only input_ids just feeds pad tokens to the model.
# The attention mask is passed explicitly so generate() does not have to guess it.
encoded = tokenizer(prompt, return_tensors="pt").to(next(model1.parameters()).device)
input_ids = encoded.input_ids

peft_model_outputs = model1.generate(
    input_ids=input_ids,
    attention_mask=encoded.attention_mask,
    generation_config=GenerationConfig(max_new_tokens=200, num_beams=1),
)
peft_model_text_output = tokenizer.decode(peft_model_outputs[0], skip_special_tokens=True)

# Defined locally (same 99-dash value as earlier) so this cell does not depend
# on the zero-shot cell having been run.
dash_line = '-' * 99

print(dash_line)
print(f'BASELINE HUMAN SUMMARY:\n{baseline_human_summary}')

print(dash_line)
print(f'PEFT MODEL: {peft_model_text_output}')

Token indices sequence length is longer than the specified maximum sequence length for this model (631 > 512). Running this sequence through the model will result in indexing errors


RuntimeError: 
  #ifdef __HIPCC__
  #define ERROR_UNSUPPORTED_CAST ;
  // corresponds to aten/src/ATen/native/cuda/thread_constants.h
  #define CUDA_OR_ROCM_NUM_THREADS 256
  // corresponds to aten/src/ATen/cuda/detail/OffsetCalculator.cuh
  #define MAX_DIMS 16
  #ifndef __forceinline__
  #define __forceinline__ inline __attribute__((always_inline))
  #endif
  #else
  //TODO use _assert_fail, because assert is disabled in non-debug builds
  #define ERROR_UNSUPPORTED_CAST assert(false);
  #define CUDA_OR_ROCM_NUM_THREADS 128
  #define MAX_DIMS 25
  #endif
  #define POS_INFINITY __int_as_float(0x7f800000)
  #define INFINITY POS_INFINITY
  #define NEG_INFINITY __int_as_float(0xff800000)
  #define NAN __int_as_float(0x7fffffff)

  typedef long long int int64_t;
  typedef unsigned int uint32_t;
  typedef signed char int8_t;
  typedef unsigned char uint8_t;  // NOTE: this MUST be "unsigned char"! "char" is equivalent to "signed char"
  typedef short int16_t;
  static_assert(sizeof(int64_t) == 8, "expected size does not match");
  static_assert(sizeof(uint32_t) == 4, "expected size does not match");
  static_assert(sizeof(int8_t) == 1, "expected size does not match");
  constexpr int num_threads = CUDA_OR_ROCM_NUM_THREADS;
  constexpr int thread_work_size = 4; // TODO: make template substitution once we decide where those vars live
  constexpr int block_work_size = thread_work_size * num_threads;

  
  
  
  namespace std {
  
  using ::signbit;
  using ::isfinite;
  using ::isinf;
  using ::isnan;
  
  using ::abs;
  
  using ::acos;
  using ::acosf;
  using ::asin;
  using ::asinf;
  using ::atan;
  using ::atanf;
  using ::atan2;
  using ::atan2f;
  using ::ceil;
  using ::ceilf;
  using ::cos;
  using ::cosf;
  using ::cosh;
  using ::coshf;
  
  using ::exp;
  using ::expf;
  
  using ::fabs;
  using ::fabsf;
  using ::floor;
  using ::floorf;
  
  using ::fmod;
  using ::fmodf;
  
  using ::frexp;
  using ::frexpf;
  using ::ldexp;
  using ::ldexpf;
  
  using ::log;
  using ::logf;
  
  using ::log10;
  using ::log10f;
  using ::modf;
  using ::modff;
  
  using ::pow;
  using ::powf;
  
  using ::sin;
  using ::sinf;
  using ::sinh;
  using ::sinhf;
  
  using ::sqrt;
  using ::sqrtf;
  using ::tan;
  using ::tanf;
  
  using ::tanh;
  using ::tanhf;
  
  using ::acosh;
  using ::acoshf;
  using ::asinh;
  using ::asinhf;
  using ::atanh;
  using ::atanhf;
  using ::cbrt;
  using ::cbrtf;
  
  using ::copysign;
  using ::copysignf;
  
  using ::erf;
  using ::erff;
  using ::erfc;
  using ::erfcf;
  using ::exp2;
  using ::exp2f;
  using ::expm1;
  using ::expm1f;
  using ::fdim;
  using ::fdimf;
  using ::fmaf;
  using ::fma;
  using ::fmax;
  using ::fmaxf;
  using ::fmin;
  using ::fminf;
  using ::hypot;
  using ::hypotf;
  using ::ilogb;
  using ::ilogbf;
  using ::lgamma;
  using ::lgammaf;
  using ::llrint;
  using ::llrintf;
  using ::llround;
  using ::llroundf;
  using ::log1p;
  using ::log1pf;
  using ::log2;
  using ::log2f;
  using ::logb;
  using ::logbf;
  using ::lrint;
  using ::lrintf;
  using ::lround;
  using ::lroundf;
  
  using ::nan;
  using ::nanf;
  
  using ::nearbyint;
  using ::nearbyintf;
  using ::nextafter;
  using ::nextafterf;
  using ::remainder;
  using ::remainderf;
  using ::remquo;
  using ::remquof;
  using ::rint;
  using ::rintf;
  using ::round;
  using ::roundf;
  using ::scalbln;
  using ::scalblnf;
  using ::scalbn;
  using ::scalbnf;
  using ::tgamma;
  using ::tgammaf;
  using ::trunc;
  using ::truncf;
  
  } // namespace std
  
  

  // NB: Order matters for this macro; it is relied upon in
  // _promoteTypesLookup and the serialization format.
  // Note, some types have ctype as void because we don't support them in codegen
  #define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_) \
  _(uint8_t, Byte) /* 0 */                               \
  _(int8_t, Char) /* 1 */                                \
  _(int16_t, Short) /* 2 */                              \
  _(int, Int) /* 3 */                                    \
  _(int64_t, Long) /* 4 */                               \
  _(at::Half, Half) /* 5 */                                  \
  _(float, Float) /* 6 */                                \
  _(double, Double) /* 7 */                              \
  _(std::complex<at::Half>, ComplexHalf) /* 8 */        \
  _(std::complex<float>, ComplexFloat) /* 9 */                          \
  _(std::complex<double>, ComplexDouble) /* 10 */                         \
  _(bool, Bool) /* 11 */                                 \
  _(void, QInt8) /* 12 */                          \
  _(void, QUInt8) /* 13 */                        \
  _(void, QInt32) /* 14 */                        \
  _(at::BFloat16, BFloat16) /* 15 */                             \

  #define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_QINT(_)       \
  _(uint8_t, Byte)                                                 \
  _(int8_t, Char)                                                  \
  _(int16_t, Short)                                                \
  _(int, Int)                                                      \
  _(int64_t, Long)                                                 \
  _(at::Half, Half)                                                \
  _(float, Float)                                                  \
  _(double, Double)                                                \
  _(std::complex<at::Half>, ComplexHalf)                           \
  _(std::complex<float>, ComplexFloat)                             \
  _(std::complex<double>, ComplexDouble)                           \
  _(bool, Bool)                                                    \
  _(at::BFloat16, BFloat16)


  enum class ScalarType : int8_t {
  #define DEFINE_ENUM(_1, n) n,
  AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_ENUM)
  #undef DEFINE_ENUM
      Undefined,
  NumOptions
  };

  template <typename T, int size>
  struct Array {
  T data[size];

  __device__ T operator[](int i) const {
      return data[i];
  }
  __device__ T& operator[](int i) {
      return data[i];
  }
  Array() = default;
  Array(const Array&) = default;
  Array& operator=(const Array&) = default;
  __device__ Array(T x) {
    for (int i = 0; i < size; i++) {
      data[i] = x;
    }
  }
  };

  
  
  
  
  



  template <typename T>
  struct DivMod {
  T div;
  T mod;

  __device__ DivMod(T _div, T _mod) {
      div = _div;
      mod = _mod;
  }
  };

  //<unsigned int>
  struct IntDivider {
  IntDivider() = default;

  __device__ inline unsigned int div(unsigned int n) const {
  unsigned int t = __umulhi(n, m1);
  return (t + n) >> shift;
  }

  __device__ inline unsigned int mod(unsigned int n) const {
  return n - div(n) * divisor;
  }

  __device__ inline DivMod<unsigned int> divmod(unsigned int n) const {
  unsigned int q = div(n);
  return DivMod<unsigned int>(q, n - q * divisor);
  }

  unsigned int divisor;  // d above.
  unsigned int m1;  // Magic number: m' above.
  unsigned int shift;  // Shift amounts.
  };

  template <int NARGS>
  struct TrivialOffsetCalculator {
    // The offset for each argument. Wrapper around fixed-size array.
    // The offsets are in # of elements, not in bytes.
    Array<unsigned int, NARGS> get(unsigned int linear_idx) const {
      Array<unsigned int, NARGS> offsets;
      #pragma unroll
      for (int arg = 0; arg < NARGS; arg++) {
        offsets[arg] = linear_idx;
      }
      return offsets;
    }
  };

  template<int NARGS>
  struct OffsetCalculator {
  OffsetCalculator() = default;
  __device__ __forceinline__ Array<unsigned int, NARGS> get(unsigned int linear_idx) const {
      Array<unsigned int, NARGS> offsets;
      #pragma unroll
      for (int arg = 0; arg < NARGS; ++arg) {
      offsets[arg] = 0;
      }

      #pragma unroll
      for (int dim = 0; dim < MAX_DIMS; ++dim) {
      if (dim == dims) {
          break;
      }

      auto divmod = sizes_[dim].divmod(linear_idx);
      linear_idx = divmod.div;

      #pragma unroll
      for (int arg = 0; arg < NARGS; ++arg) {
          offsets[arg] += divmod.mod * strides_[dim][arg];
      }
      //printf("offset calc thread dim size stride offset %d %d %d %d %d %d %d %d\n",
      //threadIdx.x, dim, sizes_[dim].divisor, strides_[dim][0], offsets[0], linear_idx, divmod.div, divmod.mod);
      }
      return offsets;
  }

    int dims;
    IntDivider sizes_[MAX_DIMS];
    // NOTE: this approach will not support nInputs == 0
    unsigned int strides_[MAX_DIMS][NARGS];
  };



  #define C10_HOST_DEVICE __host__ __device__
  #define C10_DEVICE __device__
  #if defined(__clang__) && defined(__HIP__)
  #ifndef __forceinline__
  #define __forceinline__ inline __attribute__((always_inline))
  #endif
  // until ROCm support for kernel asserts is restored
  #define assert(expr) (static_cast<void>(0))
  #endif

  template <typename T>
  __device__ __forceinline__ T WARP_SHFL_DOWN(T value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
  {
  #if defined(__clang__) && defined(__HIP__)
    return __shfl_down(value, delta, width);
  #else
    return __shfl_down_sync(mask, value, delta, width);
  #endif
  }


  #if 0
  template <typename T>
  __device__ __forceinline__ std::complex<T> WARP_SHFL_DOWN(std::complex<T> value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
  {
    return std::complex<T>(
  #if defined(__clang__) && defined(__HIP__)
        __shfl_down(value.real(), delta, width),
        __shfl_down(value.imag(), delta, width));
  #else
        __shfl_down_sync(mask, value.real(), delta, width),
        __shfl_down_sync(mask, value.imag(), delta, width));
  #endif
  }
  #endif

  // aligned vector generates vectorized load/store on CUDA
  template<typename scalar_t, int vec_size>
  struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
    scalar_t val[vec_size];
  };


  C10_HOST_DEVICE static void reduce_fraction(size_t &numerator, size_t &denominator) {
    // get GCD of num and denom using Euclid's algorithm.
    // Can replace this with std::gcd if we ever support c++17.
    size_t a = denominator;
    size_t b = numerator;
    while (b != 0) {
        a %= b;
        // swap(a,b)
        size_t tmp = a;
        a = b;
        b = tmp;
    }

    // a is now the GCD
    numerator /= a;
    denominator /= a;
  }




  struct ReduceConfig {
  //has to match host-side ReduceConfig in the eager code
  static constexpr int BLOCK_X = 0;
  static constexpr int BLOCK_Y = 1;
  static constexpr int CTA = 2;

  static constexpr int input_vec_size = 4;
  int element_size_bytes;
  int num_inputs;
  int num_outputs;
  int step_input = 1;
  int step_output = 1;
  int ctas_per_output = 1;
  int input_mult[3] = {0, 0, 0};
  int output_mult[2] = {0, 0};

  int block_width;
  int block_height;
  int num_threads;

  bool vectorize_input = false;
  int output_vec_size = 1;

  C10_HOST_DEVICE bool should_block_x_reduce() const {
    return input_mult[BLOCK_X] != 0;
  }

  C10_HOST_DEVICE bool should_block_y_reduce() const {
    return input_mult[BLOCK_Y] != 0;
  }

  C10_HOST_DEVICE bool should_global_reduce() const {
    return input_mult[CTA] != 0;
  }

  C10_DEVICE bool should_store(int output_idx) const {
    return output_idx < num_outputs &&
      (!should_block_x_reduce() || threadIdx.x == 0) &&
      (!should_block_y_reduce() || threadIdx.y == 0);
  }

  C10_DEVICE bool should_reduce_tail() const {
    return (!should_block_y_reduce() || threadIdx.y == 0) &&
      (!should_global_reduce() || blockIdx.y == 0);
  }

  C10_HOST_DEVICE int input_idx() const {
    int lane = threadIdx.x;
    int warp = threadIdx.y;
    int cta2 = blockIdx.y;
    return (lane * input_mult[BLOCK_X] +
            warp * input_mult[BLOCK_Y] +
            cta2 * input_mult[CTA]);
  }

  template <int output_vec_size>
  C10_HOST_DEVICE int output_idx() const {
    int lane = threadIdx.x;
    int warp = threadIdx.y;
    int cta1 = blockIdx.x;
    return (lane * output_mult[BLOCK_X] +
            warp * output_mult[BLOCK_Y] +
            cta1 * step_output) * output_vec_size;
  }

  C10_DEVICE int shared_memory_offset(int offset) const {
    return threadIdx.x + (threadIdx.y + offset) * blockDim.x;
  }

  C10_DEVICE int staging_memory_offset(int cta2) const {
    int offset = cta2 + blockIdx.x * gridDim.y;
    if (!should_block_x_reduce()) {
      offset = threadIdx.x + offset * blockDim.x;
    }
    return offset;
  }


  };


//TODO this will need to be different for more generic reduction functions
namespace reducer {

  using scalar_t = int64_t;
  using arg_t = int64_t;
  using out_scalar_t = int64_t;


  inline __device__ arg_t combine(arg_t a, arg_t b) { return a * b; }

  inline __device__ out_scalar_t project(arg_t arg) {
    return (out_scalar_t) arg;
  }

  inline __device__ arg_t warp_shfl_down(arg_t arg, int offset) {
    return WARP_SHFL_DOWN(arg, offset);
  }

  inline __device__ arg_t translate_idx(arg_t acc, int64_t /*idx*/) {
    return acc;
  }

  // wrap a normal reduction that ignores the index
  inline __device__ arg_t reduce(arg_t acc, arg_t val, int64_t idx) {
     return combine(acc, val);
  }
}


struct ReduceJitOp {
  using scalar_t = int64_t;
  using arg_t = int64_t;
  using out_scalar_t = int64_t;

  using InputCalculator = OffsetCalculator<1>;
  using OutputCalculator = OffsetCalculator<2>;

//   static constexpr bool can_accumulate_in_output =
//     std::is_convertible<arg_t, out_scalar_t>::value
//     && std::is_convertible<out_scalar_t, arg_t>::value;

  static constexpr int input_vec_size = ReduceConfig::input_vec_size;

  arg_t ident;
  ReduceConfig config;
  InputCalculator input_calc;
  OutputCalculator output_calc;
  const void* src;
  const char* dst[2]; //it accepts at most two destinations
  // acc_buf used for accumulation among sub Tensor Iterator when accumulation on
  // output is not permissible
  void* acc_buf;
  // cta_buf used for accumulation between blocks during global reduction
  void* cta_buf;
  int* semaphores;
  int64_t base_idx;
  bool accumulate;
  bool final_output;
  int noutputs;


  C10_DEVICE void run() const {
    extern __shared__ char shared_memory[];
    uint32_t output_idx = config.output_idx<1>();
    uint32_t input_idx = config.input_idx();
    auto base_offsets1 = output_calc.get(output_idx)[1];

    using arg_vec_t = Array<arg_t, 1>;
    arg_vec_t value;

    if (output_idx < config.num_outputs && input_idx < config.num_inputs) {
      const scalar_t* input_slice = (const scalar_t*)((const char*)src + base_offsets1);

      value = thread_reduce<1>(input_slice);
    }

    if (config.should_block_y_reduce()) {
      value = block_y_reduce<1>(value, shared_memory);
    }
    if (config.should_block_x_reduce()) {
      value = block_x_reduce<1>(value, shared_memory);
    }

    using out_ptr_vec_t = Array<out_scalar_t*, 1>;
    using offset_vec_t = Array<uint32_t, 1>;
    offset_vec_t base_offsets;
    out_ptr_vec_t out;

    #pragma unroll
    for (int i = 0; i < 1; i++) {
      base_offsets[i] = output_calc.get(output_idx + i)[0];
      out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]);
    }

    arg_vec_t* acc = nullptr;
    if (acc_buf != nullptr) {
      size_t numerator = sizeof(arg_t);
      size_t denominator = sizeof(out_scalar_t);
      reduce_fraction(numerator, denominator);
      acc = (arg_vec_t*)((char*)acc_buf + (base_offsets[0] * numerator / denominator));
    }

    if (config.should_global_reduce()) {
      value = global_reduce<1>(value, acc, shared_memory);
    } else if (config.should_store(output_idx)) {
      if (accumulate) {
        #pragma unroll
        for (int i = 0; i < 1; i++) {
          value[i] = reducer::translate_idx(value[i], base_idx);
        }
      }

      if (acc == nullptr) {
        if (accumulate) {
          value = accumulate_in_output<1>(out, value);
        }
        if (final_output) {
          set_results_to_output<1>(value, base_offsets);
        } else {
          #pragma unroll
          for (int i = 0; i < 1; i++) {
            *(out[i]) = get_accumulated_output(out[i], value[i]);
          }
        }
      } else {
        if (accumulate) {
          #pragma unroll
          for (int i = 0; i < 1; i++) {
            value[i] = reducer::combine((*acc)[i], value[i]);
          }
        }
        if (final_output) {
          set_results_to_output<1>(value, base_offsets);
        } else {
          *acc = value;
        }
      }
    }
  }

  template <int output_vec_size>
  C10_DEVICE Array<arg_t, output_vec_size> thread_reduce(const scalar_t* data) const {
    if (config.vectorize_input) {
      assert(output_vec_size == 1);
      // reduce at the header of input_slice where memory is not aligned,
      // so that thread_reduce will have an aligned memory to work on.
      return {input_vectorized_thread_reduce_impl(data)};
    } else {
      uint32_t element_stride = input_calc.strides_[0][0] / sizeof(scalar_t);
      bool is_contiguous = (input_calc.dims == 1 && element_stride == 1);
      if (is_contiguous) {
        return thread_reduce_impl<output_vec_size>(data, [](uint32_t idx) { return idx; });
      } else if (input_calc.dims == 1) {
        return thread_reduce_impl<output_vec_size>(data, [&](uint32_t idx) { return idx * element_stride; });
      } else {
        return thread_reduce_impl<output_vec_size>(data, [&](uint32_t idx) { return input_calc.get(idx)[0] / sizeof(scalar_t); });
      }
    }
  }

  C10_DEVICE arg_t input_vectorized_thread_reduce_impl(const scalar_t* data) const {
    uint32_t end = config.num_inputs;

    // Handle the head of input slice where data is not aligned
    arg_t value = ident;
    constexpr int align_bytes = alignof(aligned_vector<scalar_t, input_vec_size>);
    constexpr int align_elements = align_bytes / sizeof(scalar_t);
    int shift = ((int64_t)data) % align_bytes / sizeof(scalar_t);
    if (shift > 0) {
      data -= shift;
      end += shift;
      if(threadIdx.x >= shift && threadIdx.x < align_elements && config.should_reduce_tail()){
        value = reducer::reduce(value, data[threadIdx.x], threadIdx.x - shift);
      }
      end -= align_elements;
      data += align_elements;
      shift = align_elements - shift;
    }

    // Do the vectorized reduction
    using load_t = aligned_vector<scalar_t, input_vec_size>;

    uint32_t idx = config.input_idx();
    const uint32_t stride = config.step_input;

    // Multiple accumulators to remove dependency between unrolled loops.
    arg_t value_list[input_vec_size];
    value_list[0] = value;

    #pragma unroll
    for (int i = 1; i < input_vec_size; i++) {
      value_list[i] = ident;
    }

    scalar_t values[input_vec_size];

    load_t *values_vector = reinterpret_cast<load_t*>(&values[0]);

    while (idx * input_vec_size + input_vec_size - 1 < end) {
      *values_vector = reinterpret_cast<const load_t*>(data)[idx];
      #pragma unroll
      for (uint32_t i = 0; i < input_vec_size; i++) {
        value_list[i] = reducer::reduce(value_list[i], values[i], shift + idx * input_vec_size + i);
      }
      idx += stride;
    }

    // tail
    uint32_t tail_start = end - end % input_vec_size;
    if (config.should_reduce_tail()) {
      int idx = tail_start + threadIdx.x;
      if (idx < end) {
        value_list[0] = reducer::reduce(value_list[0], data[idx], idx + shift);
      }
    }

    // combine accumulators
    #pragma unroll
    for (int i = 1; i < input_vec_size; i++) {
      value_list[0] = reducer::combine(value_list[0], value_list[i]);
    }
    return value_list[0];
  }

  template <int output_vec_size, typename offset_calc_t>
  C10_DEVICE Array<arg_t, output_vec_size> thread_reduce_impl(const scalar_t* data_, offset_calc_t calc) const {
    uint32_t idx = config.input_idx();
    const uint32_t end = config.num_inputs;
    const uint32_t stride = config.step_input;
    const int vt0=4;

    using arg_vec_t = Array<arg_t, output_vec_size>;
    using load_t = aligned_vector<scalar_t, output_vec_size>;
    const load_t* data = reinterpret_cast<const load_t*>(data_);

    // Multiple accumulators to remove dependency between unrolled loops.
    arg_vec_t value_list[vt0];

    #pragma unroll
    for (int i = 0; i < vt0; i++) {
      #pragma unroll
      for (int j = 0; j < output_vec_size; j++) {
        value_list[i][j] = ident;
      }
    }

    load_t values[vt0];

    while (idx + (vt0 - 1) * stride < end) {
      #pragma unroll
      for (uint32_t i = 0; i < vt0; i++) {
        values[i] = data[calc(idx + i * stride) / output_vec_size];
      }
      #pragma unroll
      for (uint32_t i = 0; i < vt0; i++) {
        #pragma unroll
        for (uint32_t j = 0; j < output_vec_size; j++) {
          value_list[i][j] = reducer::reduce(value_list[i][j], values[i].val[j], idx + i * stride);
        }
      }
      idx += stride * vt0;
    }

    // tail
    int idx_ = idx;
    #pragma unroll
    for (uint32_t i = 0; i < vt0; i++) {
      if (idx >= end) {
        break;
      }
      values[i] = data[calc(idx) / output_vec_size];
      idx += stride;
    }
    idx = idx_;
    #pragma unroll
    for (uint32_t i = 0; i < vt0; i++) {
      if (idx >= end) {
        break;
      }
      #pragma unroll
      for (uint32_t j = 0; j < output_vec_size; j++) {
        value_list[i][j] = reducer::reduce(value_list[i][j], values[i].val[j], idx);
      }
      idx += stride;
    }

    // combine accumulators
    #pragma unroll
    for (int i = 1; i < vt0; i++) {
      #pragma unroll
      for (uint32_t j = 0; j < output_vec_size; j++) {
        value_list[0][j] = reducer::combine(value_list[0][j], value_list[i][j]);
      }
    }
    return value_list[0];
  }
  template <int output_vec_size>
  C10_DEVICE Array<arg_t, output_vec_size> block_x_reduce(Array<arg_t, output_vec_size> value, char* shared_memory) const {
    using args_vec_t = Array<arg_t, output_vec_size>;
    int dim_x = blockDim.x;
    args_vec_t* shared = (args_vec_t*)shared_memory;
    if (dim_x > warpSize) {
      int address_base = threadIdx.x + threadIdx.y*blockDim.x;
      shared[address_base] = value;
      for (int offset = dim_x/2; offset >= warpSize; offset >>= 1) {
        __syncthreads();
        if (threadIdx.x < offset && threadIdx.x + offset < blockDim.x) {
          args_vec_t other = shared[address_base + offset];
          #pragma unroll
          for (int i = 0; i < output_vec_size; i++) {
            value[i] = reducer::combine(value[i], other[i]);
          }
          shared[address_base] = value;
        }
      }
      dim_x = warpSize;
    }

    __syncthreads();

    for (int offset = 1; offset < dim_x; offset <<= 1) {
      #pragma unroll
      for (int i = 0; i < output_vec_size; i++) {
        arg_t other = reducer::warp_shfl_down(value[i], offset);
        value[i] = reducer::combine(value[i], other);
      }
    }
    return value;
  }

  template <int output_vec_size>
  C10_DEVICE Array<arg_t, output_vec_size> block_y_reduce(Array<arg_t, output_vec_size> value, char* shared_memory) const {
    using args_vec_t = Array<arg_t, output_vec_size>;
    args_vec_t* shared = (args_vec_t*)shared_memory;
    shared[config.shared_memory_offset(0)] = value;
    for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) {
      __syncthreads();
      if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
        args_vec_t other = shared[config.shared_memory_offset(offset)];
        #pragma unroll
        for (int i = 0; i < output_vec_size; i++) {
          value[i] = reducer::combine(value[i], other[i]);
        }
        shared[config.shared_memory_offset(0)] = value;
      }
    }
    return value;
  }
  

  C10_DEVICE bool mark_block_finished() const {
    __shared__ bool is_last_block_done_shared;

    __syncthreads();
    if (threadIdx.x == 0 && threadIdx.y == 0) {
      int prev_blocks_finished = atomicAdd(&semaphores[blockIdx.x], 1);
      is_last_block_done_shared = (prev_blocks_finished == gridDim.y - 1);
    }

    __syncthreads();

    return is_last_block_done_shared;
  }

  template <int output_vec_size>
  C10_DEVICE Array<arg_t, output_vec_size> accumulate_in_output(
    Array<out_scalar_t*, output_vec_size> out,
    Array<arg_t, output_vec_size> value
  ) const {
    Array<arg_t, output_vec_size> ret;
    #pragma unroll
    for (int i = 0; i < output_vec_size; i++) {
      ret[i] = reducer::combine(*(out[i]), value[i]);
    }
    return ret;
  }


  C10_DEVICE out_scalar_t get_accumulated_output(
    out_scalar_t* out, arg_t value
  ) const {
    assert(!final_output);
    return (out_scalar_t)value;
  }

  template<class T>
  C10_DEVICE void set_results(const T x, const uint32_t base_offset) const {
    assert(noutputs == 1);
    auto res = (out_scalar_t*)((char*)dst[0] + base_offset);
    *res = x;
  }

//TODO - multi-output reduction - we won't be able to use thrust::pair
//just explicitly specify typed output reads/writes
//Currently implemented for max of two outputs
//   template<class T1, class T2>
//   C10_DEVICE void set_results(const thrust::pair<T1, T2> x, const index_t base_offset) const {
//     if (noutputs >= 1) {
//       auto res0 = (T1*)((char*)dst[0] + base_offset);
//       *res0 = x.first;
//     }
//     if (noutputs >= 2) {
//       // base offset is computed assuming element size being sizeof(T1), so we need to make a
//       // correction to obtain the correct base offset
//       auto res1 = (T2*) ((char *) dst[1] + base_offset / sizeof(T1) * sizeof(T2));
//       *res1 = x.second;
//     }
//   }

  template <int output_vec_size>
  C10_DEVICE void set_results_to_output(Array<arg_t, output_vec_size> value, Array<uint32_t, output_vec_size> base_offset) const {
    assert(final_output);
    #pragma unroll
    for (int i = 0; i < output_vec_size; i++) {
      set_results(reducer::project(value[i]), base_offset[i]);
    }
  }

  template <int output_vec_size>
  C10_DEVICE Array<arg_t, output_vec_size> global_reduce(Array<arg_t, output_vec_size> value, Array<arg_t, output_vec_size> *acc, char* shared_memory) const {
    using arg_vec_t = Array<arg_t, output_vec_size>;
    using out_ptr_vec_t = Array<out_scalar_t*, output_vec_size>;
    using offset_vec_t = Array<uint32_t, output_vec_size>;

    arg_vec_t* reduce_buffer = (arg_vec_t*)cta_buf;
    uint32_t output_idx = config.output_idx<output_vec_size>();
    offset_vec_t base_offsets;
    out_ptr_vec_t out;

    #pragma unroll
    for (int i = 0; i < output_vec_size; i++) {
      base_offsets[i] = output_calc.get(output_idx + i)[0];
      out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]);
    }

    bool should_store = config.should_store(output_idx);
    if (should_store) {
      uint32_t offset = config.staging_memory_offset(blockIdx.y);
      reduce_buffer[offset] = value;
    }

    __threadfence(); // make sure writes are globally visible
    __syncthreads(); // if multiple warps in this block wrote to staging, make sure they're all done
    bool is_last_block_done = mark_block_finished();

    if (is_last_block_done) {
      value = ident;
      if (config.should_block_x_reduce()) {
        uint32_t input_offset = threadIdx.x + threadIdx.y * blockDim.x;
        uint32_t step = blockDim.x * blockDim.y;
        for (; input_offset < config.ctas_per_output; input_offset += step) {
          uint32_t idx = config.staging_memory_offset(input_offset);
          arg_vec_t next = reduce_buffer[idx];
          #pragma unroll
          for (int i = 0; i < output_vec_size; i++) {
            value[i] = reducer::combine(value[i], next[i]);
          }
        }
      } else {
        uint32_t input_offset = threadIdx.y;
        uint32_t step = blockDim.y;
        for (; input_offset < config.ctas_per_output; input_offset += step) {
          uint32_t idx = config.staging_memory_offset(input_offset);
          arg_vec_t next = reduce_buffer[idx];
          #pragma unroll
          for (int i = 0; i < output_vec_size; i++) {
            value[i] = reducer::combine(value[i], next[i]);
          }
        }
      }
      value = block_y_reduce(value, shared_memory);
      if (config.should_block_x_reduce()) {
        value = block_x_reduce<output_vec_size>(value, shared_memory);
      }
      if (should_store) {
        if (accumulate) {
          #pragma unroll
          for (int i = 0; i < output_vec_size; i++) {
            value[i] = reducer::translate_idx(value[i], base_idx);
          }
        }

        if (acc == nullptr) {
          if (accumulate) {
            value = accumulate_in_output<output_vec_size>(out, value);
          }
          if (final_output) {
            set_results_to_output<output_vec_size>(value, base_offsets);
          } else {
            #pragma unroll
            for (int i = 0; i < output_vec_size; i++) {
              *(out[i]) = get_accumulated_output(out[i], value[i]);
            }
          }
        } else {
          if (accumulate) {
            #pragma unroll
            for (int i = 0; i < output_vec_size; i++) {
              value[i] = reducer::combine((*acc)[i], value[i]);
            }
          }
          if (final_output) {
            set_results_to_output<output_vec_size>(value, base_offsets);
          } else {
            *acc = value;
          }
        }
      }
    }

    return value;
  }
};

extern "C"
__launch_bounds__(512, 4)
__global__ void reduction_prod_kernel(ReduceJitOp r){
  r.run();
}
nvrtc: error: failed to open libnvrtc-builtins.so.12.1.
  Make sure that libnvrtc-builtins.so.12.1 is installed correctly.

In [23]:
# Summarize one example with the PEFT model (`model1`) and print the result
# next to the human-written reference summary.
#
# NOTE(review): the recorded run of this cell ended in a RuntimeError whose
# message is a JIT-compiled CUDA reduction kernel; the same dump printed just
# before this cell ends with nvrtc failing to open libnvrtc-builtins.so.12.1.
# That looks like a CUDA runtime/installation problem rather than something
# this cell's code controls -- verify the installed torch/CUDA libraries.

max_sequence_length = 512  # Per-segment token cap; adjust to what your model supports

# Tokenize the two segments separately.
# - truncation=True is what actually enforces max_length; without it the
#   tokenizer only warns and returns the over-long sequence unchanged.
# - No padding: this is a single-example batch, and padding each segment to
#   max_length before concatenating would leave a block of pad tokens between
#   the dialogue and the prompt and always feed >= 2 * max_sequence_length
#   tokens to the encoder.
# NOTE(review): the combined input can still be up to 2 * max_sequence_length
# tokens, and if the tokenizer appends an end-of-sequence token (T5 tokenizers
# do) one will sit between the two segments -- confirm this is intended, or
# build a single prompt string and tokenize it once.
dialogue_tokens = tokenizer(
    dialogue,
    truncation=True,
    max_length=max_sequence_length,
    return_tensors="pt",
)
prompt_tokens = tokenizer(
    prompt,
    truncation=True,
    max_length=max_sequence_length,
    return_tensors="pt",
)

# Resolve the model's device once and build the concatenated inputs
# (dialogue first, then prompt) directly on it.
model_device = next(model1.parameters()).device
input_ids = torch.cat(
    [dialogue_tokens.input_ids, prompt_tokens.input_ids], dim=1
).to(model_device)
attention_mask = torch.cat(
    [dialogue_tokens.attention_mask, prompt_tokens.attention_mask], dim=1
).to(model_device)

# Generate the summary (num_beams=1, up to 200 new tokens).
peft_model_outputs = model1.generate(
    input_ids=input_ids,
    attention_mask=attention_mask,
    generation_config=GenerationConfig(max_new_tokens=200, num_beams=1),
)
peft_model_text_output = tokenizer.decode(peft_model_outputs[0], skip_special_tokens=True)

# Print the results
print(dash_line)
print(f'BASELINE HUMAN SUMMARY:\n{baseline_human_summary}')

print(dash_line)
print(f'PEFT MODEL: {peft_model_text_output}')


RuntimeError: 
  #ifdef __HIPCC__
  #define ERROR_UNSUPPORTED_CAST ;
  // corresponds to aten/src/ATen/native/cuda/thread_constants.h
  #define CUDA_OR_ROCM_NUM_THREADS 256
  // corresponds to aten/src/ATen/cuda/detail/OffsetCalculator.cuh
  #define MAX_DIMS 16
  #ifndef __forceinline__
  #define __forceinline__ inline __attribute__((always_inline))
  #endif
  #else
  //TODO use _assert_fail, because assert is disabled in non-debug builds
  #define ERROR_UNSUPPORTED_CAST assert(false);
  #define CUDA_OR_ROCM_NUM_THREADS 128
  #define MAX_DIMS 25
  #endif
  #define POS_INFINITY __int_as_float(0x7f800000)
  #define INFINITY POS_INFINITY
  #define NEG_INFINITY __int_as_float(0xff800000)
  #define NAN __int_as_float(0x7fffffff)

  typedef long long int int64_t;
  typedef unsigned int uint32_t;
  typedef signed char int8_t;
  typedef unsigned char uint8_t;  // NOTE: this MUST be "unsigned char"! "char" is equivalent to "signed char"
  typedef short int16_t;
  static_assert(sizeof(int64_t) == 8, "expected size does not match");
  static_assert(sizeof(uint32_t) == 4, "expected size does not match");
  static_assert(sizeof(int8_t) == 1, "expected size does not match");
  constexpr int num_threads = CUDA_OR_ROCM_NUM_THREADS;
  constexpr int thread_work_size = 4; // TODO: make template substitution once we decide where those vars live
  constexpr int block_work_size = thread_work_size * num_threads;

  
  
  
  namespace std {
  
  using ::signbit;
  using ::isfinite;
  using ::isinf;
  using ::isnan;
  
  using ::abs;
  
  using ::acos;
  using ::acosf;
  using ::asin;
  using ::asinf;
  using ::atan;
  using ::atanf;
  using ::atan2;
  using ::atan2f;
  using ::ceil;
  using ::ceilf;
  using ::cos;
  using ::cosf;
  using ::cosh;
  using ::coshf;
  
  using ::exp;
  using ::expf;
  
  using ::fabs;
  using ::fabsf;
  using ::floor;
  using ::floorf;
  
  using ::fmod;
  using ::fmodf;
  
  using ::frexp;
  using ::frexpf;
  using ::ldexp;
  using ::ldexpf;
  
  using ::log;
  using ::logf;
  
  using ::log10;
  using ::log10f;
  using ::modf;
  using ::modff;
  
  using ::pow;
  using ::powf;
  
  using ::sin;
  using ::sinf;
  using ::sinh;
  using ::sinhf;
  
  using ::sqrt;
  using ::sqrtf;
  using ::tan;
  using ::tanf;
  
  using ::tanh;
  using ::tanhf;
  
  using ::acosh;
  using ::acoshf;
  using ::asinh;
  using ::asinhf;
  using ::atanh;
  using ::atanhf;
  using ::cbrt;
  using ::cbrtf;
  
  using ::copysign;
  using ::copysignf;
  
  using ::erf;
  using ::erff;
  using ::erfc;
  using ::erfcf;
  using ::exp2;
  using ::exp2f;
  using ::expm1;
  using ::expm1f;
  using ::fdim;
  using ::fdimf;
  using ::fmaf;
  using ::fma;
  using ::fmax;
  using ::fmaxf;
  using ::fmin;
  using ::fminf;
  using ::hypot;
  using ::hypotf;
  using ::ilogb;
  using ::ilogbf;
  using ::lgamma;
  using ::lgammaf;
  using ::llrint;
  using ::llrintf;
  using ::llround;
  using ::llroundf;
  using ::log1p;
  using ::log1pf;
  using ::log2;
  using ::log2f;
  using ::logb;
  using ::logbf;
  using ::lrint;
  using ::lrintf;
  using ::lround;
  using ::lroundf;
  
  using ::nan;
  using ::nanf;
  
  using ::nearbyint;
  using ::nearbyintf;
  using ::nextafter;
  using ::nextafterf;
  using ::remainder;
  using ::remainderf;
  using ::remquo;
  using ::remquof;
  using ::rint;
  using ::rintf;
  using ::round;
  using ::roundf;
  using ::scalbln;
  using ::scalblnf;
  using ::scalbn;
  using ::scalbnf;
  using ::tgamma;
  using ::tgammaf;
  using ::trunc;
  using ::truncf;
  
  } // namespace std
  
  

  // NB: Order matters for this macro; it is relied upon in
  // _promoteTypesLookup and the serialization format.
  // Note, some types have ctype as void because we don't support them in codegen
  #define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_) \
  _(uint8_t, Byte) /* 0 */                               \
  _(int8_t, Char) /* 1 */                                \
  _(int16_t, Short) /* 2 */                              \
  _(int, Int) /* 3 */                                    \
  _(int64_t, Long) /* 4 */                               \
  _(at::Half, Half) /* 5 */                                  \
  _(float, Float) /* 6 */                                \
  _(double, Double) /* 7 */                              \
  _(std::complex<at::Half>, ComplexHalf) /* 8 */        \
  _(std::complex<float>, ComplexFloat) /* 9 */                          \
  _(std::complex<double>, ComplexDouble) /* 10 */                         \
  _(bool, Bool) /* 11 */                                 \
  _(void, QInt8) /* 12 */                          \
  _(void, QUInt8) /* 13 */                        \
  _(void, QInt32) /* 14 */                        \
  _(at::BFloat16, BFloat16) /* 15 */                             \

  #define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_QINT(_)       \
  _(uint8_t, Byte)                                                 \
  _(int8_t, Char)                                                  \
  _(int16_t, Short)                                                \
  _(int, Int)                                                      \
  _(int64_t, Long)                                                 \
  _(at::Half, Half)                                                \
  _(float, Float)                                                  \
  _(double, Double)                                                \
  _(std::complex<at::Half>, ComplexHalf)                           \
  _(std::complex<float>, ComplexFloat)                             \
  _(std::complex<double>, ComplexDouble)                           \
  _(bool, Bool)                                                    \
  _(at::BFloat16, BFloat16)


  enum class ScalarType : int8_t {
  #define DEFINE_ENUM(_1, n) n,
  AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_ENUM)
  #undef DEFINE_ENUM
      Undefined,
  NumOptions
  };

  template <typename T, int size>
  struct Array {
  T data[size];

  __device__ T operator[](int i) const {
      return data[i];
  }
  __device__ T& operator[](int i) {
      return data[i];
  }
  Array() = default;
  Array(const Array&) = default;
  Array& operator=(const Array&) = default;
  __device__ Array(T x) {
    for (int i = 0; i < size; i++) {
      data[i] = x;
    }
  }
  };

  
  
  
  
  



  template <typename T>
  struct DivMod {
  T div;
  T mod;

  __device__ DivMod(T _div, T _mod) {
      div = _div;
      mod = _mod;
  }
  };

  //<unsigned int>
  struct IntDivider {
  IntDivider() = default;

  __device__ inline unsigned int div(unsigned int n) const {
  unsigned int t = __umulhi(n, m1);
  return (t + n) >> shift;
  }

  __device__ inline unsigned int mod(unsigned int n) const {
  return n - div(n) * divisor;
  }

  __device__ inline DivMod<unsigned int> divmod(unsigned int n) const {
  unsigned int q = div(n);
  return DivMod<unsigned int>(q, n - q * divisor);
  }

  unsigned int divisor;  // d above.
  unsigned int m1;  // Magic number: m' above.
  unsigned int shift;  // Shift amounts.
  };

  template <int NARGS>
  struct TrivialOffsetCalculator {
    // The offset for each argument. Wrapper around fixed-size array.
    // The offsets are in # of elements, not in bytes.
    Array<unsigned int, NARGS> get(unsigned int linear_idx) const {
      Array<unsigned int, NARGS> offsets;
      #pragma unroll
      for (int arg = 0; arg < NARGS; arg++) {
        offsets[arg] = linear_idx;
      }
      return offsets;
    }
  };

  template<int NARGS>
  struct OffsetCalculator {
  OffsetCalculator() = default;
  __device__ __forceinline__ Array<unsigned int, NARGS> get(unsigned int linear_idx) const {
      Array<unsigned int, NARGS> offsets;
      #pragma unroll
      for (int arg = 0; arg < NARGS; ++arg) {
      offsets[arg] = 0;
      }

      #pragma unroll
      for (int dim = 0; dim < MAX_DIMS; ++dim) {
      if (dim == dims) {
          break;
      }

      auto divmod = sizes_[dim].divmod(linear_idx);
      linear_idx = divmod.div;

      #pragma unroll
      for (int arg = 0; arg < NARGS; ++arg) {
          offsets[arg] += divmod.mod * strides_[dim][arg];
      }
      //printf("offset calc thread dim size stride offset %d %d %d %d %d %d %d %d\n",
      //threadIdx.x, dim, sizes_[dim].divisor, strides_[dim][0], offsets[0], linear_idx, divmod.div, divmod.mod);
      }
      return offsets;
  }

    int dims;
    IntDivider sizes_[MAX_DIMS];
    // NOTE: this approach will not support nInputs == 0
    unsigned int strides_[MAX_DIMS][NARGS];
  };



  #define C10_HOST_DEVICE __host__ __device__
  #define C10_DEVICE __device__
  #if defined(__clang__) && defined(__HIP__)
  #ifndef __forceinline__
  #define __forceinline__ inline __attribute__((always_inline))
  #endif
  // until ROCm support for kernel asserts is restored
  #define assert(expr) (static_cast<void>(0))
  #endif

  template <typename T>
  __device__ __forceinline__ T WARP_SHFL_DOWN(T value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
  {
  #if defined(__clang__) && defined(__HIP__)
    return __shfl_down(value, delta, width);
  #else
    return __shfl_down_sync(mask, value, delta, width);
  #endif
  }


  #if 0
  template <typename T>
  __device__ __forceinline__ std::complex<T> WARP_SHFL_DOWN(std::complex<T> value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
  {
    return std::complex<T>(
  #if defined(__clang__) && defined(__HIP__)
        __shfl_down(value.real(), delta, width),
        __shfl_down(value.imag(), delta, width));
  #else
        __shfl_down_sync(mask, value.real(), delta, width),
        __shfl_down_sync(mask, value.imag(), delta, width));
  #endif
  }
  #endif

  // aligned vector generates vectorized load/store on CUDA
  template<typename scalar_t, int vec_size>
  struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
    scalar_t val[vec_size];
  };


  C10_HOST_DEVICE static void reduce_fraction(size_t &numerator, size_t &denominator) {
    // get GCD of num and denom using Euclid's algorithm.
    // Can replace this with std::gcd if we ever support c++17.
    size_t a = denominator;
    size_t b = numerator;
    while (b != 0) {
        a %= b;
        // swap(a,b)
        size_t tmp = a;
        a = b;
        b = tmp;
    }

    // a is now the GCD
    numerator /= a;
    denominator /= a;
  }




  struct ReduceConfig {
  //has to match host-side ReduceConfig in the eager code
  static constexpr int BLOCK_X = 0;
  static constexpr int BLOCK_Y = 1;
  static constexpr int CTA = 2;

  static constexpr int input_vec_size = 4;
  int element_size_bytes;
  int num_inputs;
  int num_outputs;
  int step_input = 1;
  int step_output = 1;
  int ctas_per_output = 1;
  int input_mult[3] = {0, 0, 0};
  int output_mult[2] = {0, 0};

  int block_width;
  int block_height;
  int num_threads;

  bool vectorize_input = false;
  int output_vec_size = 1;

  C10_HOST_DEVICE bool should_block_x_reduce() const {
    return input_mult[BLOCK_X] != 0;
  }

  C10_HOST_DEVICE bool should_block_y_reduce() const {
    return input_mult[BLOCK_Y] != 0;
  }

  C10_HOST_DEVICE bool should_global_reduce() const {
    return input_mult[CTA] != 0;
  }

  C10_DEVICE bool should_store(int output_idx) const {
    return output_idx < num_outputs &&
      (!should_block_x_reduce() || threadIdx.x == 0) &&
      (!should_block_y_reduce() || threadIdx.y == 0);
  }

  C10_DEVICE bool should_reduce_tail() const {
    return (!should_block_y_reduce() || threadIdx.y == 0) &&
      (!should_global_reduce() || blockIdx.y == 0);
  }

  C10_HOST_DEVICE int input_idx() const {
    int lane = threadIdx.x;
    int warp = threadIdx.y;
    int cta2 = blockIdx.y;
    return (lane * input_mult[BLOCK_X] +
            warp * input_mult[BLOCK_Y] +
            cta2 * input_mult[CTA]);
  }

  template <int output_vec_size>
  C10_HOST_DEVICE int output_idx() const {
    int lane = threadIdx.x;
    int warp = threadIdx.y;
    int cta1 = blockIdx.x;
    return (lane * output_mult[BLOCK_X] +
            warp * output_mult[BLOCK_Y] +
            cta1 * step_output) * output_vec_size;
  }

  C10_DEVICE int shared_memory_offset(int offset) const {
    return threadIdx.x + (threadIdx.y + offset) * blockDim.x;
  }

  C10_DEVICE int staging_memory_offset(int cta2) const {
    int offset = cta2 + blockIdx.x * gridDim.y;
    if (!should_block_x_reduce()) {
      offset = threadIdx.x + offset * blockDim.x;
    }
    return offset;
  }


  };


//TODO this will need to be different for more generic reduction functions
namespace reducer {

  using scalar_t = int64_t;
  using arg_t = int64_t;
  using out_scalar_t = int64_t;


  inline __device__ arg_t combine(arg_t a, arg_t b) { return a * b; }

  inline __device__ out_scalar_t project(arg_t arg) {
    return (out_scalar_t) arg;
  }

  inline __device__ arg_t warp_shfl_down(arg_t arg, int offset) {
    return WARP_SHFL_DOWN(arg, offset);
  }

  inline __device__ arg_t translate_idx(arg_t acc, int64_t /*idx*/) {
    return acc;
  }

  // wrap a normal reduction that ignores the index
  inline __device__ arg_t reduce(arg_t acc, arg_t val, int64_t idx) {
     return combine(acc, val);
  }
}


struct ReduceJitOp {
  using scalar_t = int64_t;
  using arg_t = int64_t;
  using out_scalar_t = int64_t;

  using InputCalculator = OffsetCalculator<1>;
  using OutputCalculator = OffsetCalculator<2>;

//   static constexpr bool can_accumulate_in_output =
//     std::is_convertible<arg_t, out_scalar_t>::value
//     && std::is_convertible<out_scalar_t, arg_t>::value;

  static constexpr int input_vec_size = ReduceConfig::input_vec_size;

  arg_t ident;
  ReduceConfig config;
  InputCalculator input_calc;
  OutputCalculator output_calc;
  const void* src;
  const char* dst[2]; //it accepts at most two destinations
  // acc_buf used for accumulation among sub Tensor Iterator when accumulation on
  // output is not permissible
  void* acc_buf;
  // cta_buf used for accumulation between blocks during global reduction
  void* cta_buf;
  int* semaphores;
  int64_t base_idx;
  bool accumulate;
  bool final_output;
  int noutputs;


  C10_DEVICE void run() const {
    extern __shared__ char shared_memory[];
    uint32_t output_idx = config.output_idx<1>();
    uint32_t input_idx = config.input_idx();
    auto base_offsets1 = output_calc.get(output_idx)[1];

    using arg_vec_t = Array<arg_t, 1>;
    arg_vec_t value;

    if (output_idx < config.num_outputs && input_idx < config.num_inputs) {
      const scalar_t* input_slice = (const scalar_t*)((const char*)src + base_offsets1);

      value = thread_reduce<1>(input_slice);
    }

    if (config.should_block_y_reduce()) {
      value = block_y_reduce<1>(value, shared_memory);
    }
    if (config.should_block_x_reduce()) {
      value = block_x_reduce<1>(value, shared_memory);
    }

    using out_ptr_vec_t = Array<out_scalar_t*, 1>;
    using offset_vec_t = Array<uint32_t, 1>;
    offset_vec_t base_offsets;
    out_ptr_vec_t out;

    #pragma unroll
    for (int i = 0; i < 1; i++) {
      base_offsets[i] = output_calc.get(output_idx + i)[0];
      out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]);
    }

    arg_vec_t* acc = nullptr;
    if (acc_buf != nullptr) {
      size_t numerator = sizeof(arg_t);
      size_t denominator = sizeof(out_scalar_t);
      reduce_fraction(numerator, denominator);
      acc = (arg_vec_t*)((char*)acc_buf + (base_offsets[0] * numerator / denominator));
    }

    if (config.should_global_reduce()) {
      value = global_reduce<1>(value, acc, shared_memory);
    } else if (config.should_store(output_idx)) {
      if (accumulate) {
        #pragma unroll
        for (int i = 0; i < 1; i++) {
          value[i] = reducer::translate_idx(value[i], base_idx);
        }
      }

      if (acc == nullptr) {
        if (accumulate) {
          value = accumulate_in_output<1>(out, value);
        }
        if (final_output) {
          set_results_to_output<1>(value, base_offsets);
        } else {
          #pragma unroll
          for (int i = 0; i < 1; i++) {
            *(out[i]) = get_accumulated_output(out[i], value[i]);
          }
        }
      } else {
        if (accumulate) {
          #pragma unroll
          for (int i = 0; i < 1; i++) {
            value[i] = reducer::combine((*acc)[i], value[i]);
          }
        }
        if (final_output) {
          set_results_to_output<1>(value, base_offsets);
        } else {
          *acc = value;
        }
      }
    }
  }

  template <int output_vec_size>
  C10_DEVICE Array<arg_t, output_vec_size> thread_reduce(const scalar_t* data) const {
    if (config.vectorize_input) {
      assert(output_vec_size == 1);
      // reduce at the header of input_slice where memory is not aligned,
      // so that thread_reduce will have an aligned memory to work on.
      return {input_vectorized_thread_reduce_impl(data)};
    } else {
      uint32_t element_stride = input_calc.strides_[0][0] / sizeof(scalar_t);
      bool is_contiguous = (input_calc.dims == 1 && element_stride == 1);
      if (is_contiguous) {
        return thread_reduce_impl<output_vec_size>(data, [](uint32_t idx) { return idx; });
      } else if (input_calc.dims == 1) {
        return thread_reduce_impl<output_vec_size>(data, [&](uint32_t idx) { return idx * element_stride; });
      } else {
        return thread_reduce_impl<output_vec_size>(data, [&](uint32_t idx) { return input_calc.get(idx)[0] / sizeof(scalar_t); });
      }
    }
  }

  C10_DEVICE arg_t input_vectorized_thread_reduce_impl(const scalar_t* data) const {
    uint32_t end = config.num_inputs;

    // Handle the head of input slice where data is not aligned
    arg_t value = ident;
    constexpr int align_bytes = alignof(aligned_vector<scalar_t, input_vec_size>);
    constexpr int align_elements = align_bytes / sizeof(scalar_t);
    int shift = ((int64_t)data) % align_bytes / sizeof(scalar_t);
    if (shift > 0) {
      data -= shift;
      end += shift;
      if(threadIdx.x >= shift && threadIdx.x < align_elements && config.should_reduce_tail()){
        value = reducer::reduce(value, data[threadIdx.x], threadIdx.x - shift);
      }
      end -= align_elements;
      data += align_elements;
      shift = align_elements - shift;
    }

    // Do the vectorized reduction
    using load_t = aligned_vector<scalar_t, input_vec_size>;

    uint32_t idx = config.input_idx();
    const uint32_t stride = config.step_input;

    // Multiple accumulators to remove dependency between unrolled loops.
    arg_t value_list[input_vec_size];
    value_list[0] = value;

    #pragma unroll
    for (int i = 1; i < input_vec_size; i++) {
      value_list[i] = ident;
    }

    scalar_t values[input_vec_size];

    load_t *values_vector = reinterpret_cast<load_t*>(&values[0]);

    while (idx * input_vec_size + input_vec_size - 1 < end) {
      *values_vector = reinterpret_cast<const load_t*>(data)[idx];
      #pragma unroll
      for (uint32_t i = 0; i < input_vec_size; i++) {
        value_list[i] = reducer::reduce(value_list[i], values[i], shift + idx * input_vec_size + i);
      }
      idx += stride;
    }

    // tail
    uint32_t tail_start = end - end % input_vec_size;
    if (config.should_reduce_tail()) {
      int idx = tail_start + threadIdx.x;
      if (idx < end) {
        value_list[0] = reducer::reduce(value_list[0], data[idx], idx + shift);
      }
    }

    // combine accumulators
    #pragma unroll
    for (int i = 1; i < input_vec_size; i++) {
      value_list[0] = reducer::combine(value_list[0], value_list[i]);
    }
    return value_list[0];
  }

  template <int output_vec_size, typename offset_calc_t>
  C10_DEVICE Array<arg_t, output_vec_size> thread_reduce_impl(const scalar_t* data_, offset_calc_t calc) const {
    uint32_t idx = config.input_idx();
    const uint32_t end = config.num_inputs;
    const uint32_t stride = config.step_input;
    const int vt0=4;

    using arg_vec_t = Array<arg_t, output_vec_size>;
    using load_t = aligned_vector<scalar_t, output_vec_size>;
    const load_t* data = reinterpret_cast<const load_t*>(data_);

    // Multiple accumulators to remove dependency between unrolled loops.
    arg_vec_t value_list[vt0];

    #pragma unroll
    for (int i = 0; i < vt0; i++) {
      #pragma unroll
      for (int j = 0; j < output_vec_size; j++) {
        value_list[i][j] = ident;
      }
    }

    load_t values[vt0];

    while (idx + (vt0 - 1) * stride < end) {
      #pragma unroll
      for (uint32_t i = 0; i < vt0; i++) {
        values[i] = data[calc(idx + i * stride) / output_vec_size];
      }
      #pragma unroll
      for (uint32_t i = 0; i < vt0; i++) {
        #pragma unroll
        for (uint32_t j = 0; j < output_vec_size; j++) {
          value_list[i][j] = reducer::reduce(value_list[i][j], values[i].val[j], idx + i * stride);
        }
      }
      idx += stride * vt0;
    }

    // tail
    int idx_ = idx;
    #pragma unroll
    for (uint32_t i = 0; i < vt0; i++) {
      if (idx >= end) {
        break;
      }
      values[i] = data[calc(idx) / output_vec_size];
      idx += stride;
    }
    idx = idx_;
    #pragma unroll
    for (uint32_t i = 0; i < vt0; i++) {
      if (idx >= end) {
        break;
      }
      #pragma unroll
      for (uint32_t j = 0; j < output_vec_size; j++) {
        value_list[i][j] = reducer::reduce(value_list[i][j], values[i].val[j], idx);
      }
      idx += stride;
    }

    // combine accumulators
    #pragma unroll
    for (int i = 1; i < vt0; i++) {
      #pragma unroll
      for (uint32_t j = 0; j < output_vec_size; j++) {
        value_list[0][j] = reducer::combine(value_list[0][j], value_list[i][j]);
      }
    }
    return value_list[0];
  }
  template <int output_vec_size>
  C10_DEVICE Array<arg_t, output_vec_size> block_x_reduce(Array<arg_t, output_vec_size> value, char* shared_memory) const {
    using args_vec_t = Array<arg_t, output_vec_size>;
    int dim_x = blockDim.x;
    args_vec_t* shared = (args_vec_t*)shared_memory;
    if (dim_x > warpSize) {
      int address_base = threadIdx.x + threadIdx.y*blockDim.x;
      shared[address_base] = value;
      for (int offset = dim_x/2; offset >= warpSize; offset >>= 1) {
        __syncthreads();
        if (threadIdx.x < offset && threadIdx.x + offset < blockDim.x) {
          args_vec_t other = shared[address_base + offset];
          #pragma unroll
          for (int i = 0; i < output_vec_size; i++) {
            value[i] = reducer::combine(value[i], other[i]);
          }
          shared[address_base] = value;
        }
      }
      dim_x = warpSize;
    }

    __syncthreads();

    for (int offset = 1; offset < dim_x; offset <<= 1) {
      #pragma unroll
      for (int i = 0; i < output_vec_size; i++) {
        arg_t other = reducer::warp_shfl_down(value[i], offset);
        value[i] = reducer::combine(value[i], other);
      }
    }
    return value;
  }

  template <int output_vec_size>
  C10_DEVICE Array<arg_t, output_vec_size> block_y_reduce(Array<arg_t, output_vec_size> value, char* shared_memory) const {
    using args_vec_t = Array<arg_t, output_vec_size>;
    args_vec_t* shared = (args_vec_t*)shared_memory;
    shared[config.shared_memory_offset(0)] = value;
    for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) {
      __syncthreads();
      if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
        args_vec_t other = shared[config.shared_memory_offset(offset)];
        #pragma unroll
        for (int i = 0; i < output_vec_size; i++) {
          value[i] = reducer::combine(value[i], other[i]);
        }
        shared[config.shared_memory_offset(0)] = value;
      }
    }
    return value;
  }
  

  C10_DEVICE bool mark_block_finished() const {
    __shared__ bool is_last_block_done_shared;

    __syncthreads();
    if (threadIdx.x == 0 && threadIdx.y == 0) {
      int prev_blocks_finished = atomicAdd(&semaphores[blockIdx.x], 1);
      is_last_block_done_shared = (prev_blocks_finished == gridDim.y - 1);
    }

    __syncthreads();

    return is_last_block_done_shared;
  }

  template <int output_vec_size>
  C10_DEVICE Array<arg_t, output_vec_size> accumulate_in_output(
    Array<out_scalar_t*, output_vec_size> out,
    Array<arg_t, output_vec_size> value
  ) const {
    Array<arg_t, output_vec_size> ret;
    #pragma unroll
    for (int i = 0; i < output_vec_size; i++) {
      ret[i] = reducer::combine(*(out[i]), value[i]);
    }
    return ret;
  }


  C10_DEVICE out_scalar_t get_accumulated_output(
    out_scalar_t* out, arg_t value
  ) const {
    assert(!final_output);
    return (out_scalar_t)value;
  }

  template<class T>
  C10_DEVICE void set_results(const T x, const uint32_t base_offset) const {
    assert(noutputs == 1);
    auto res = (out_scalar_t*)((char*)dst[0] + base_offset);
    *res = x;
  }

//TODO - multi-output reduction - we won't be able to use thrust::pair
//just explicitly specify typed output reads/writes
//Currently implemented for max of two outputs
//   template<class T1, class T2>
//   C10_DEVICE void set_results(const thrust::pair<T1, T2> x, const index_t base_offset) const {
//     if (noutputs >= 1) {
//       auto res0 = (T1*)((char*)dst[0] + base_offset);
//       *res0 = x.first;
//     }
//     if (noutputs >= 2) {
//       // base offset is computed assuming element size being sizeof(T1), so we need to make a
//       // correction to obtain the correct base offset
//       auto res1 = (T2*) ((char *) dst[1] + base_offset / sizeof(T1) * sizeof(T2));
//       *res1 = x.second;
//     }
//   }

  // Writes the final result for every lane: each value is passed through
  // reducer::project and stored via set_results at that lane's byte offset.
  // Only valid when final_output is true (asserted).
  template <int output_vec_size>
  C10_DEVICE void set_results_to_output(Array<arg_t, output_vec_size> value, Array<uint32_t, output_vec_size> base_offset) const {
    assert(final_output);
    #pragma unroll
    for (int lane = 0; lane < output_vec_size; ++lane) {
      const auto projected = reducer::project(value[lane]);
      set_results(projected, base_offset[lane]);
    }
  }

  // Cross-block stage of the reduction.
  //
  // Each block stages its partial result in `cta_buf`, then signals completion
  // through mark_block_finished(). The block for which mark_block_finished()
  // returns true re-reads the staged partials, combines them, runs the
  // in-block reductions, and (for threads where should_store holds) writes or
  // accumulates the result.
  //
  // Parameters:
  //   value         - this thread's partial result, one entry per output lane.
  //   acc           - optional accumulator. When nullptr, results are written
  //                   through the output pointers derived from dst[0];
  //                   otherwise non-final results are stored into *acc.
  //   shared_memory - scratch buffer forwarded to block_y_reduce / block_x_reduce.
  //
  // Returns:
  //   In the block that finishes last, the value produced by the final
  //   reduction (including the accumulate steps when should_store holds).
  //   In every other block, `value` is returned unchanged.
  template <int output_vec_size>
  C10_DEVICE Array<arg_t, output_vec_size> global_reduce(Array<arg_t, output_vec_size> value, Array<arg_t, output_vec_size> *acc, char* shared_memory) const {
    using arg_vec_t = Array<arg_t, output_vec_size>;
    using out_ptr_vec_t = Array<out_scalar_t*, output_vec_size>;
    using offset_vec_t = Array<uint32_t, output_vec_size>;

    // The staging buffer is addressed in units of whole arg_vec_t vectors.
    arg_vec_t* reduce_buffer = (arg_vec_t*)cta_buf;
    uint32_t output_idx = config.output_idx<output_vec_size>();
    offset_vec_t base_offsets;
    out_ptr_vec_t out;

    // Resolve, per lane, the byte offset into the first output and the
    // corresponding typed output pointer.
    #pragma unroll
    for (int i = 0; i < output_vec_size; i++) {
      base_offsets[i] = output_calc.get(output_idx + i)[0];
      out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]);
    }

    // Stage this thread's partial result at the slot computed for blockIdx.y.
    bool should_store = config.should_store(output_idx);
    if (should_store) {
      uint32_t offset = config.staging_memory_offset(blockIdx.y);
      reduce_buffer[offset] = value;
    }

    __threadfence(); // make sure writes are globally visible
    __syncthreads(); // if multiple warps in this block wrote to staging, make sure they're all done
    bool is_last_block_done = mark_block_finished();

    if (is_last_block_done) {
      // Restart from the identity and fold in the staged partials.
      value = ident;
      if (config.should_block_x_reduce()) {
        // Entries are distributed over all threads of the block: start at the
        // flattened (x, y) thread index and advance by blockDim.x * blockDim.y.
        uint32_t input_offset = threadIdx.x + threadIdx.y * blockDim.x;
        uint32_t step = blockDim.x * blockDim.y;
        for (; input_offset < config.ctas_per_output; input_offset += step) {
          uint32_t idx = config.staging_memory_offset(input_offset);
          arg_vec_t next = reduce_buffer[idx];
          #pragma unroll
          for (int i = 0; i < output_vec_size; i++) {
            value[i] = reducer::combine(value[i], next[i]);
          }
        }
      } else {
        // Entries are distributed over threadIdx.y only: start at threadIdx.y
        // and advance by blockDim.y.
        uint32_t input_offset = threadIdx.y;
        uint32_t step = blockDim.y;
        for (; input_offset < config.ctas_per_output; input_offset += step) {
          uint32_t idx = config.staging_memory_offset(input_offset);
          arg_vec_t next = reduce_buffer[idx];
          #pragma unroll
          for (int i = 0; i < output_vec_size; i++) {
            value[i] = reducer::combine(value[i], next[i]);
          }
        }
      }
      // Merge the per-thread partials within the block: always along y, and
      // along x only when the config requests it.
      value = block_y_reduce(value, shared_memory);
      if (config.should_block_x_reduce()) {
        value = block_x_reduce<output_vec_size>(value, shared_memory);
      }
      if (should_store) {
        if (accumulate) {
          // NOTE(review): translate_idx is applied with base_idx before merging
          // with earlier results; presumably it rebases indices carried by the
          // reducer's value type -- confirm against the reducer definition.
          #pragma unroll
          for (int i = 0; i < output_vec_size; i++) {
            value[i] = reducer::translate_idx(value[i], base_idx);
          }
        }

        if (acc == nullptr) {
          // No accumulator: previous results (if accumulating) are read from
          // the output pointers, and the result goes back to the output.
          if (accumulate) {
            value = accumulate_in_output<output_vec_size>(out, value);
          }
          if (final_output) {
            // Final pass: project each lane and store it.
            set_results_to_output<output_vec_size>(value, base_offsets);
          } else {
            // Intermediate pass: store the unprojected value cast to out_scalar_t.
            #pragma unroll
            for (int i = 0; i < output_vec_size; i++) {
              *(out[i]) = get_accumulated_output(out[i], value[i]);
            }
          }
        } else {
          // Accumulator supplied: previous results (if accumulating) come
          // from *acc instead of the output.
          if (accumulate) {
            #pragma unroll
            for (int i = 0; i < output_vec_size; i++) {
              value[i] = reducer::combine((*acc)[i], value[i]);
            }
          }
          if (final_output) {
            set_results_to_output<output_vec_size>(value, base_offsets);
          } else {
            // Intermediate pass: keep the running result in the accumulator.
            *acc = value;
          }
        }
      }
    }

    return value;
  }
};

// Kernel entry point for the reduction: it only calls run() on the ReduceJitOp
// that is passed in by value.
// extern "C" gives the kernel C linkage (an unmangled symbol name).
// NOTE(review): presumably so the runtime-compiled kernel can be looked up by
// name -- confirm against the loader.
// __launch_bounds__(512, 4): at most 512 threads per block, with a requested
// minimum of 4 resident blocks per multiprocessor.
extern "C"
__launch_bounds__(512, 4)
__global__ void reduction_prod_kernel(ReduceJitOp r){
  r.run();
}
nvrtc: error: failed to open libnvrtc-builtins.so.12.1.
  Make sure that libnvrtc-builtins.so.12.1 is installed correctly.