diff --git a/examples/sycl/CMakeLists.txt b/examples/sycl/CMakeLists.txt
index ae4c080102..5a618461fa 100644
--- a/examples/sycl/CMakeLists.txt
+++ b/examples/sycl/CMakeLists.txt
@@ -39,6 +39,7 @@ if(SYCL_INTEL_TARGET)
   add_subdirectory(08_bmg_gemm_f8)
   add_subdirectory(09_bmg_grouped_gemm_f8)
   add_subdirectory(10_bmg_grouped_gemm_mixed_dtype)
+  add_subdirectory(sdpa_bwd)
 endif()
 
 if (CUTLASS_ENABLE_SYCL)
diff --git a/examples/sycl/sdpa_bwd/CMakeLists.txt b/examples/sycl/sdpa_bwd/CMakeLists.txt
new file mode 100644
index 0000000000..8db8ae2445
--- /dev/null
+++ b/examples/sycl/sdpa_bwd/CMakeLists.txt
@@ -0,0 +1,10 @@
+set(TEST_GROUPS --groups=128)
+
+cutlass_example_add_executable(
+  sdpa_backward
+  sdpa_backward.cpp
+  cnpy.cpp
+)
+target_link_options(sdpa_backward PUBLIC "-lz")
+set_target_properties(sdpa_backward PROPERTIES CXX_COMPILER_LAUNCHER "IGC_VectorAliasBBThreshold=10000" )
+set_target_properties(sdpa_backward PROPERTIES CXX_LINKER_LAUNCHER "IGC_VectorAliasBBThreshold=10000" )
diff --git a/examples/sycl/sdpa_bwd/cnpy.cpp b/examples/sycl/sdpa_bwd/cnpy.cpp
new file mode 100644
index 0000000000..2d28578643
--- /dev/null
+++ b/examples/sycl/sdpa_bwd/cnpy.cpp
@@ -0,0 +1,340 @@
+//Copyright (C) 2011 Carl Rogers
+//Released under MIT License
+//license available in LICENSE file, or at http://www.opensource.org/licenses/mit-license.php
+
+#include"cnpy.h"
+#include<complex>
+#include<cstdlib>
+#include<algorithm>
+#include<cstring>
+#include<iomanip>
+#include<stdint.h>
+#include<stdexcept>
+#include<regex>
+
+char cnpy::BigEndianTest() {
+    int x = 1;
+    return (((char *)&x)[0]) ? '<' : '>';
+}
+
+char cnpy::map_type(const std::type_info& t)
+{
+    if(t == typeid(float) ) return 'f';
+    if(t == typeid(double) ) return 'f';
+    if(t == typeid(long double) ) return 'f';
+
+    if(t == typeid(int) ) return 'i';
+    if(t == typeid(char) ) return 'i';
+    if(t == typeid(short) ) return 'i';
+    if(t == typeid(long) ) return 'i';
+    if(t == typeid(long long) ) return 'i';
+
+    if(t == typeid(unsigned char) ) return 'u';
+    if(t == typeid(unsigned short) ) return 'u';
+    if(t == typeid(unsigned long) ) return 'u';
+    if(t == typeid(unsigned long long) ) return 'u';
+    if(t == typeid(unsigned int) ) return 'u';
+
+    if(t == typeid(bool) ) return 'b';
+
+    if(t == typeid(std::complex<float>) ) return 'c';
+    if(t == typeid(std::complex<double>) ) return 'c';
+    if(t == typeid(std::complex<long double>) ) return 'c';
+
+    else return '?';
+}
+
+template<> std::vector<char>& cnpy::operator+=(std::vector<char>& lhs, const std::string rhs) {
+    lhs.insert(lhs.end(),rhs.begin(),rhs.end());
+    return lhs;
+}
+
+template<> std::vector<char>& cnpy::operator+=(std::vector<char>& lhs, const char* rhs) {
+    //write in little endian
+    size_t len = strlen(rhs);
+    lhs.reserve(len);
+    for(size_t byte = 0; byte < len; byte++) {
+        lhs.push_back(rhs[byte]);
+    }
+    return lhs;
+}
+
+void cnpy::parse_npy_header(unsigned char* buffer,size_t& word_size, std::vector<size_t>& shape, bool& fortran_order) {
+    //std::string magic_string(buffer,6);
+    uint8_t major_version = *reinterpret_cast<uint8_t*>(buffer+6);
+    uint8_t minor_version = *reinterpret_cast<uint8_t*>(buffer+7);
+    uint16_t header_len = *reinterpret_cast<uint16_t*>(buffer+8);
+    std::string header(reinterpret_cast<char*>(buffer+9),header_len);
+
+    size_t loc1, loc2;
+
+    //fortran order
+    loc1 = header.find("fortran_order")+16;
+    fortran_order = (header.substr(loc1,4) == "True" ?
true : false); + + //shape + loc1 = header.find("("); + loc2 = header.find(")"); + + std::regex num_regex("[0-9][0-9]*"); + std::smatch sm; + shape.clear(); + + std::string str_shape = header.substr(loc1+1,loc2-loc1-1); + while(std::regex_search(str_shape, sm, num_regex)) { + shape.push_back(std::stoi(sm[0].str())); + str_shape = sm.suffix().str(); + } + + //endian, word size, data type + //byte order code | stands for not applicable. + //not sure when this applies except for byte array + loc1 = header.find("descr")+9; + bool littleEndian = (header[loc1] == '<' || header[loc1] == '|' ? true : false); + assert(littleEndian); + + //char type = header[loc1+1]; + //assert(type == map_type(T)); + + std::string str_ws = header.substr(loc1+2); + loc2 = str_ws.find("'"); + word_size = atoi(str_ws.substr(0,loc2).c_str()); +} + +void cnpy::parse_npy_header(FILE* fp, size_t& word_size, std::vector& shape, bool& fortran_order) { + char buffer[256]; + size_t res = fread(buffer,sizeof(char),11,fp); + if(res != 11) + throw std::runtime_error("parse_npy_header: failed fread"); + std::string header = fgets(buffer,256,fp); + assert(header[header.size()-1] == '\n'); + + size_t loc1, loc2; + + //fortran order + loc1 = header.find("fortran_order"); + if (loc1 == std::string::npos) + throw std::runtime_error("parse_npy_header: failed to find header keyword: 'fortran_order'"); + loc1 += 16; + fortran_order = (header.substr(loc1,4) == "True" ? true : false); + + //shape + loc1 = header.find("("); + loc2 = header.find(")"); + if (loc1 == std::string::npos || loc2 == std::string::npos) + throw std::runtime_error("parse_npy_header: failed to find header keyword: '(' or ')'"); + + std::regex num_regex("[0-9][0-9]*"); + std::smatch sm; + shape.clear(); + + std::string str_shape = header.substr(loc1+1,loc2-loc1-1); + while(std::regex_search(str_shape, sm, num_regex)) { + shape.push_back(std::stoi(sm[0].str())); + str_shape = sm.suffix().str(); + } + + //endian, word size, data type + //byte order code | stands for not applicable. + //not sure when this applies except for byte array + loc1 = header.find("descr"); + if (loc1 == std::string::npos) + throw std::runtime_error("parse_npy_header: failed to find header keyword: 'descr'"); + loc1 += 9; + bool littleEndian = (header[loc1] == '<' || header[loc1] == '|' ? 
true : false); + assert(littleEndian); + + //char type = header[loc1+1]; + //assert(type == map_type(T)); + + std::string str_ws = header.substr(loc1+2); + loc2 = str_ws.find("'"); + word_size = atoi(str_ws.substr(0,loc2).c_str()); +} + +void cnpy::parse_zip_footer(FILE* fp, uint16_t& nrecs, size_t& global_header_size, size_t& global_header_offset) +{ + std::vector footer(22); + fseek(fp,-22,SEEK_END); + size_t res = fread(&footer[0],sizeof(char),22,fp); + if(res != 22) + throw std::runtime_error("parse_zip_footer: failed fread"); + + uint16_t disk_no, disk_start, nrecs_on_disk, comment_len; + disk_no = *(uint16_t*) &footer[4]; + disk_start = *(uint16_t*) &footer[6]; + nrecs_on_disk = *(uint16_t*) &footer[8]; + nrecs = *(uint16_t*) &footer[10]; + global_header_size = *(uint32_t*) &footer[12]; + global_header_offset = *(uint32_t*) &footer[16]; + comment_len = *(uint16_t*) &footer[20]; + + assert(disk_no == 0); + assert(disk_start == 0); + assert(nrecs_on_disk == nrecs); + assert(comment_len == 0); +} + +cnpy::NpyArray load_the_npy_file(FILE* fp) { + std::vector shape; + size_t word_size; + bool fortran_order; + cnpy::parse_npy_header(fp,word_size,shape,fortran_order); + + cnpy::NpyArray arr(shape, word_size, fortran_order); + size_t nread = fread(arr.data(),1,arr.num_bytes(),fp); + if(nread != arr.num_bytes()) + throw std::runtime_error("load_the_npy_file: failed fread"); + return arr; +} + +cnpy::NpyArray load_the_npz_array(FILE* fp, uint32_t compr_bytes, uint32_t uncompr_bytes) { + + std::vector buffer_compr(compr_bytes); + std::vector buffer_uncompr(uncompr_bytes); + size_t nread = fread(&buffer_compr[0],1,compr_bytes,fp); + if(nread != compr_bytes) + throw std::runtime_error("load_the_npy_file: failed fread"); + + int err; + z_stream d_stream; + + d_stream.zalloc = Z_NULL; + d_stream.zfree = Z_NULL; + d_stream.opaque = Z_NULL; + d_stream.avail_in = 0; + d_stream.next_in = Z_NULL; + err = inflateInit2(&d_stream, -MAX_WBITS); + + d_stream.avail_in = compr_bytes; + d_stream.next_in = &buffer_compr[0]; + d_stream.avail_out = uncompr_bytes; + d_stream.next_out = &buffer_uncompr[0]; + + err = inflate(&d_stream, Z_FINISH); + err = inflateEnd(&d_stream); + + std::vector shape; + size_t word_size; + bool fortran_order; + cnpy::parse_npy_header(&buffer_uncompr[0],word_size,shape,fortran_order); + + cnpy::NpyArray array(shape, word_size, fortran_order); + + size_t offset = uncompr_bytes - array.num_bytes(); + memcpy(array.data(),&buffer_uncompr[0]+offset,array.num_bytes()); + + return array; +} + +cnpy::npz_t cnpy::npz_load(std::string fname) { + FILE* fp = fopen(fname.c_str(),"rb"); + + if(!fp) { + throw std::runtime_error("npz_load: Error! 
Unable to open file "+fname+"!"); + } + + cnpy::npz_t arrays; + + while(1) { + std::vector local_header(30); + size_t headerres = fread(&local_header[0],sizeof(char),30,fp); + if(headerres != 30) + throw std::runtime_error("npz_load: failed fread"); + + //if we've reached the global header, stop reading + if(local_header[2] != 0x03 || local_header[3] != 0x04) break; + + //read in the variable name + uint16_t name_len = *(uint16_t*) &local_header[26]; + std::string varname(name_len,' '); + size_t vname_res = fread(&varname[0],sizeof(char),name_len,fp); + if(vname_res != name_len) + throw std::runtime_error("npz_load: failed fread"); + + //erase the lagging .npy + varname.erase(varname.end()-4,varname.end()); + + //read in the extra field + uint16_t extra_field_len = *(uint16_t*) &local_header[28]; + if(extra_field_len > 0) { + std::vector buff(extra_field_len); + size_t efield_res = fread(&buff[0],sizeof(char),extra_field_len,fp); + if(efield_res != extra_field_len) + throw std::runtime_error("npz_load: failed fread"); + } + + uint16_t compr_method = *reinterpret_cast(&local_header[0]+8); + uint32_t compr_bytes = *reinterpret_cast(&local_header[0]+18); + uint32_t uncompr_bytes = *reinterpret_cast(&local_header[0]+22); + + if(compr_method == 0) {arrays[varname] = load_the_npy_file(fp);} + else {arrays[varname] = load_the_npz_array(fp,compr_bytes,uncompr_bytes);} + } + + fclose(fp); + return arrays; +} + +cnpy::NpyArray cnpy::npz_load(std::string fname, std::string varname) { + FILE* fp = fopen(fname.c_str(),"rb"); + + if(!fp) throw std::runtime_error("npz_load: Unable to open file "+fname); + + while(1) { + std::vector local_header(30); + size_t header_res = fread(&local_header[0],sizeof(char),30,fp); + if(header_res != 30) + throw std::runtime_error("npz_load: failed fread"); + + //if we've reached the global header, stop reading + if(local_header[2] != 0x03 || local_header[3] != 0x04) break; + + //read in the variable name + uint16_t name_len = *(uint16_t*) &local_header[26]; + std::string vname(name_len,' '); + size_t vname_res = fread(&vname[0],sizeof(char),name_len,fp); + if(vname_res != name_len) + throw std::runtime_error("npz_load: failed fread"); + vname.erase(vname.end()-4,vname.end()); //erase the lagging .npy + + //read in the extra field + uint16_t extra_field_len = *(uint16_t*) &local_header[28]; + fseek(fp,extra_field_len,SEEK_CUR); //skip past the extra field + + uint16_t compr_method = *reinterpret_cast(&local_header[0]+8); + uint32_t compr_bytes = *reinterpret_cast(&local_header[0]+18); + uint32_t uncompr_bytes = *reinterpret_cast(&local_header[0]+22); + + if(vname == varname) { + NpyArray array = (compr_method == 0) ? 
load_the_npy_file(fp) : load_the_npz_array(fp,compr_bytes,uncompr_bytes); + fclose(fp); + return array; + } + else { + //skip past the data + uint32_t size = *(uint32_t*) &local_header[22]; + fseek(fp,size,SEEK_CUR); + } + } + + fclose(fp); + + //if we get here, we haven't found the variable in the file + throw std::runtime_error("npz_load: Variable name "+varname+" not found in "+fname); +} + +cnpy::NpyArray cnpy::npy_load(std::string fname) { + + FILE* fp = fopen(fname.c_str(), "rb"); + + if(!fp) throw std::runtime_error("npy_load: Unable to open file "+fname); + + NpyArray arr = load_the_npy_file(fp); + + fclose(fp); + return arr; +} + + + diff --git a/examples/sycl/sdpa_bwd/cnpy.h b/examples/sycl/sdpa_bwd/cnpy.h new file mode 100644 index 0000000000..0d3bb4c3c2 --- /dev/null +++ b/examples/sycl/sdpa_bwd/cnpy.h @@ -0,0 +1,269 @@ +//Copyright (C) 2011 Carl Rogers +//Released under MIT License +//license available in LICENSE file, or at http://www.opensource.org/licenses/mit-license.php + +#ifndef LIBCNPY_H_ +#define LIBCNPY_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace cnpy { + + struct NpyArray { + NpyArray(const std::vector& _shape, size_t _word_size, bool _fortran_order) : + shape(_shape), word_size(_word_size), fortran_order(_fortran_order) + { + num_vals = 1; + for(size_t i = 0;i < shape.size();i++) num_vals *= shape[i]; + data_holder = std::shared_ptr>( + new std::vector(num_vals * word_size)); + } + + NpyArray() : shape(0), word_size(0), fortran_order(0), num_vals(0) { } + + template + T* data() { + return reinterpret_cast(&(*data_holder)[0]); + } + + template + const T* data() const { + return reinterpret_cast(&(*data_holder)[0]); + } + + template + std::vector as_vec() const { + const T* p = data(); + return std::vector(p, p+num_vals); + } + + size_t num_bytes() const { + return data_holder->size(); + } + + std::shared_ptr> data_holder; + std::vector shape; + size_t word_size; + bool fortran_order; + size_t num_vals; + }; + + using npz_t = std::map; + + char BigEndianTest(); + char map_type(const std::type_info& t); + template std::vector create_npy_header(const std::vector& shape); + void parse_npy_header(FILE* fp,size_t& word_size, std::vector& shape, bool& fortran_order); + void parse_npy_header(unsigned char* buffer,size_t& word_size, std::vector& shape, bool& fortran_order); + void parse_zip_footer(FILE* fp, uint16_t& nrecs, size_t& global_header_size, size_t& global_header_offset); + npz_t npz_load(std::string fname); + NpyArray npz_load(std::string fname, std::string varname); + NpyArray npy_load(std::string fname); + + template std::vector& operator+=(std::vector& lhs, const T rhs) { + //write in little endian + for(size_t byte = 0; byte < sizeof(T); byte++) { + char val = *((char*)&rhs+byte); + lhs.push_back(val); + } + return lhs; + } + + template<> std::vector& operator+=(std::vector& lhs, const std::string rhs); + template<> std::vector& operator+=(std::vector& lhs, const char* rhs); + + + template void npy_save(std::string fname, const T* data, const std::vector shape, std::string mode = "w") { + FILE* fp = NULL; + std::vector true_data_shape; //if appending, the shape of existing + new data + + if(mode == "a") fp = fopen(fname.c_str(),"r+b"); + + if(fp) { + //file exists. we need to append to it. 
read the header, modify the array size + size_t word_size; + bool fortran_order; + parse_npy_header(fp,word_size,true_data_shape,fortran_order); + assert(!fortran_order); + + if(word_size != sizeof(T)) { + std::cout<<"libnpy error: "< header = create_npy_header(true_data_shape); + size_t nels = std::accumulate(shape.begin(),shape.end(),1,std::multiplies()); + + fseek(fp,0,SEEK_SET); + fwrite(&header[0],sizeof(char),header.size(),fp); + fseek(fp,0,SEEK_END); + fwrite(data,sizeof(T),nels,fp); + fclose(fp); + } + + template void npz_save(std::string zipname, std::string fname, const T* data, const std::vector& shape, std::string mode = "w") + { + //first, append a .npy to the fname + fname += ".npy"; + + //now, on with the show + FILE* fp = NULL; + uint16_t nrecs = 0; + size_t global_header_offset = 0; + std::vector global_header; + + if(mode == "a") fp = fopen(zipname.c_str(),"r+b"); + + if(fp) { + //zip file exists. we need to add a new npy file to it. + //first read the footer. this gives us the offset and size of the global header + //then read and store the global header. + //below, we will write the the new data at the start of the global header then append the global header and footer below it + size_t global_header_size; + parse_zip_footer(fp,nrecs,global_header_size,global_header_offset); + fseek(fp,global_header_offset,SEEK_SET); + global_header.resize(global_header_size); + size_t res = fread(&global_header[0],sizeof(char),global_header_size,fp); + if(res != global_header_size){ + throw std::runtime_error("npz_save: header read error while adding to existing zip"); + } + fseek(fp,global_header_offset,SEEK_SET); + } + else { + fp = fopen(zipname.c_str(),"wb"); + } + + std::vector npy_header = create_npy_header(shape); + + size_t nels = std::accumulate(shape.begin(),shape.end(),1,std::multiplies()); + size_t nbytes = nels*sizeof(T) + npy_header.size(); + + //get the CRC of the data to be added + uint32_t crc = crc32(0L,(uint8_t*)&npy_header[0],npy_header.size()); + crc = crc32(crc,(uint8_t*)data,nels*sizeof(T)); + + //build the local header + std::vector local_header; + local_header += "PK"; //first part of sig + local_header += (uint16_t) 0x0403; //second part of sig + local_header += (uint16_t) 20; //min version to extract + local_header += (uint16_t) 0; //general purpose bit flag + local_header += (uint16_t) 0; //compression method + local_header += (uint16_t) 0; //file last mod time + local_header += (uint16_t) 0; //file last mod date + local_header += (uint32_t) crc; //crc + local_header += (uint32_t) nbytes; //compressed size + local_header += (uint32_t) nbytes; //uncompressed size + local_header += (uint16_t) fname.size(); //fname length + local_header += (uint16_t) 0; //extra field length + local_header += fname; + + //build global header + global_header += "PK"; //first part of sig + global_header += (uint16_t) 0x0201; //second part of sig + global_header += (uint16_t) 20; //version made by + global_header.insert(global_header.end(),local_header.begin()+4,local_header.begin()+30); + global_header += (uint16_t) 0; //file comment length + global_header += (uint16_t) 0; //disk number where file starts + global_header += (uint16_t) 0; //internal file attributes + global_header += (uint32_t) 0; //external file attributes + global_header += (uint32_t) global_header_offset; //relative offset of local file header, since it begins where the global header used to begin + global_header += fname; + + //build footer + std::vector footer; + footer += "PK"; //first part of sig + footer += 
(uint16_t) 0x0605; //second part of sig + footer += (uint16_t) 0; //number of this disk + footer += (uint16_t) 0; //disk where footer starts + footer += (uint16_t) (nrecs+1); //number of records on this disk + footer += (uint16_t) (nrecs+1); //total number of records + footer += (uint32_t) global_header.size(); //nbytes of global headers + footer += (uint32_t) (global_header_offset + nbytes + local_header.size()); //offset of start of global headers, since global header now starts after newly written array + footer += (uint16_t) 0; //zip file comment length + + //write everything + fwrite(&local_header[0],sizeof(char),local_header.size(),fp); + fwrite(&npy_header[0],sizeof(char),npy_header.size(),fp); + fwrite(data,sizeof(T),nels,fp); + fwrite(&global_header[0],sizeof(char),global_header.size(),fp); + fwrite(&footer[0],sizeof(char),footer.size(),fp); + fclose(fp); + } + + template void npy_save(std::string fname, const std::vector data, std::string mode = "w") { + std::vector shape; + shape.push_back(data.size()); + npy_save(fname, &data[0], shape, mode); + } + + template void npz_save(std::string zipname, std::string fname, const std::vector data, std::string mode = "w") { + std::vector shape; + shape.push_back(data.size()); + npz_save(zipname, fname, &data[0], shape, mode); + } + + template std::vector create_npy_header(const std::vector& shape) { + + std::vector dict; + dict += "{'descr': '"; + dict += BigEndianTest(); + dict += map_type(typeid(T)); + dict += std::to_string(sizeof(T)); + dict += "', 'fortran_order': False, 'shape': ("; + dict += std::to_string(shape[0]); + for(size_t i = 1;i < shape.size();i++) { + dict += ", "; + dict += std::to_string(shape[i]); + } + if(shape.size() == 1) dict += ","; + dict += "), }"; + //pad with spaces so that preamble+dict is modulo 16 bytes. preamble is 10 bytes. 
dict needs to end with \n + int remainder = 16 - (10 + dict.size()) % 16; + dict.insert(dict.end(),remainder,' '); + dict.back() = '\n'; + + std::vector header; + header += (char) 0x93; + header += "NUMPY"; + header += (char) 0x01; //major version of numpy format + header += (char) 0x00; //minor version of numpy format + header += (uint16_t) dict.size(); + header.insert(header.end(),dict.begin(),dict.end()); + + return header; + } + + +} + +#endif diff --git a/examples/sycl/sdpa_bwd/params.hpp b/examples/sycl/sdpa_bwd/params.hpp new file mode 100644 index 0000000000..9655fa34c1 --- /dev/null +++ b/examples/sycl/sdpa_bwd/params.hpp @@ -0,0 +1,476 @@ +#pragma once +#include +#include +#include +using namespace cute; + +template +struct FAKernel { + /* + Q BATCH,NUM_HEAD_Q,SEQ_LEN_QO,HEAD_SIZE_QK + K BATCH,NUM_HEAD_KV,SEQ_LEN_KV,HEAD_SIZE_QK + V BATCH,NUM_HEAD_KV,SEQ_LEN_KV,HEAD_SIZE_VO + P BATCH,NUM_HEAD_Q,SEQ_LEN_QO,SEQ_LEN_KV + O BATCH,NUM_HEAD_Q,SEQ_LEN_QO,HEAD_SIZE_VO + */ + using DType = T_; + using VType = float; // accumulation + using MMA_Atom_ARCH = std::conditional_t< + std::is_same_v, + MMA_Atom, + MMA_Atom>; + static constexpr int kHeadDim = kHeadDim_; + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kBlockK = kBlockK_; + static constexpr int kNSGs = kNSGs_; + // using SubgroupLayout = Layout, _1, _1>, Stride<_1, _1, _1>>; + static constexpr int AtomLayoutMSdP = 16 *kNSGs / kBlockN; + // static constexpr int AtomLayoutNdKV = 16 *kNSGs / kHeadDim; + static constexpr int AtomLayoutNdKV = 2; + static constexpr int AtomLayoutMdQ = kBlockM / 32; + // static constexpr int AtomLayoutMSdP = 4; + // static constexpr int AtomLayoutNdKV = 2; + // static constexpr int AtomLayoutMdQ = 2; + using SubgroupLayoutSdP = Layout, Int, _1>>; + using SubgroupLayoutdKV = Layout, Int, _1>>; + using SubgroupLayoutdQ = Layout, Int, _1>>; + // static_assert(16 *AtomLayoutMSdP == kBlockM); + // static_assert(32 *kNSGs / AtomLayoutMSdP == kBlockN); + // static_assert(kBlockK == 32); + // using TileShapeSdP = Tile, Int<16 * kNSGs / AtomLayoutMSdP>, Int>; + using TileShapeSdP = Tile, Int, Int>; + static_assert(size<0>(TileShapeSdP{}) == kBlockM); + static_assert(size<1>(TileShapeSdP{}) == kBlockN); + // static_assert(size<2>(TileShapeSdP{}) == kBlockK); + // using TileShapedKV = Tile, Int<32 * kNSGs / AtomLayoutNdKV>, Int>; + using TileShapedKV = Tile, Int, Int>; + static_assert(size<0>(TileShapedKV{}) == kBlockN); + static_assert(size<1>(TileShapedKV{}) == kHeadDim); + // static_assert(size<2>(TileShapedKV{}) == kBlockK); + // using TileShapedQ = Tile, Int<32 * kNSGs / AtomLayoutMdQ>, Int>; + using TileShapedQ = Tile, Int, Int>; + static_assert(size<0>(TileShapedQ{}) == kBlockM); + static_assert(size<1>(TileShapedQ{}) == kHeadDim); + // static_assert(size<2>(TileShapedQ{}) == kBlockK); + + // using SubgroupLayout = Layout, Stride<_1, _1, _1>>; + // using TileShapeMSdP = Shape, Int, Int>; + using TiledMmaSdP = typename TiledMMAHelper, + SubgroupLayoutSdP>::TiledMMA; + + using TiledMmadKV = typename TiledMMAHelper, + SubgroupLayoutdKV>::TiledMMA; + + using TiledMmadQ = typename TiledMMAHelper, + SubgroupLayoutdQ>::TiledMMA; + static constexpr auto bP = Int<2>{}; // Pipeline + + using StrideR = cute::tuple>; + using StrideC = cute::tuple, long>; + + // for load Q and Kt in S=QKt + using TiledLoadQ = decltype(make_tiled_copy( + Copy_Atom, DType>{}, + Layout>{}, // Thr layout 1x16 k-major + Layout>{})); // Val layout 16x1 + using TiledLoadKt = 
decltype(make_tiled_copy( + Copy_Atom, DType>{}, + Layout>{}, // Thr layout 1x16 n-major + Layout>{})); // Val layout 16x1 + + // for load dO and Vt in dP=dO*Vt + using TiledLoaddO = decltype(make_tiled_copy( + Copy_Atom, DType>{}, + Layout>{}, // Thr layout 1x16 k-major + Layout>{})); // Val layout 16x1 + + using TiledLoadV = decltype(make_tiled_copy( + Copy_Atom, DType>{}, + Layout>{}, // Thr layout 1x16 n-major + Layout>{})); // Val layout 16x1 + + // for load Pt and dO in dV=Pt*dO + using TiledLoadPt = decltype(make_tiled_copy( + Copy_Atom, DType>{}, + Layout>{}, // Thr layout 1x16 m-major + Layout>{})); // // Val layout 8x1 + using TiledLoaddOt = decltype(make_tiled_copy( + Copy_Atom, DType>{}, // should be V here + Layout>{}, // Thr layout 1x16 n-major + Layout>{})); // val layout 16x1 + + // for load dP, K and dQ in dQ=dP*K + using TiledLoaddP = decltype(make_tiled_copy( + Copy_Atom, DType>{}, + Layout>{}, // Thr layout 1x16 k-major + Layout>{})); // val layout 16x1 + using TiledLoadK = decltype(make_tiled_copy( + Copy_Atom, DType>{}, + Layout>{}, // Thr layout 1x16 n-major + Layout>{})); // val layout 16x1 + + using TiledLoaddQ = decltype(make_tiled_copy( + Copy_Atom, VType>{}, + Layout>{}, // Thr layout 1x16 n-major + Layout>{})); // val layout 8x1 + + // for load dPt, Q in dK=dPt*Q + using TiledLoaddPt = decltype(make_tiled_copy( + Copy_Atom, DType>{}, + Layout>{}, // Thr layout 1x16 k-major + Layout>{})); // Val layout 16x1 + using TiledLoadQt = decltype(make_tiled_copy( + Copy_Atom, DType>{}, + Layout>{}, // Thr layout 1x16 n-major + Layout>{})); // Val layout 16x1 + + // for save S in S=QKt and P + using TiledSaveS = decltype(make_tiled_copy( + Copy_Atom, DType>{}, + Layout>{}, // Thr layout 1x16 n-major + Layout>{})); // Val layout 8x1 + // for save dP in dP=dO*Vt + using TiledSavedP = decltype(make_tiled_copy( + Copy_Atom, DType>{}, + Layout>{}, // Thr layout 1x16 n-major + Layout>{})); // Val layout 8x1 + // for save dV in dV=Pt*dO + using TiledSavedV = decltype(make_tiled_copy( + Copy_Atom, DType>{}, + Layout>{}, // Thr layout 1x16 n-major + Layout>{})); // Val layout 8x1 + // for save dQ in dQ=dP*K + using TiledSavedQ = decltype(make_tiled_copy( + Copy_Atom, VType>{}, + Layout>{}, // Thr layout 1x16 n-major + Layout>{})); // val layout 8x1 + // for save dK=dPt*Q + using TiledSavedK = decltype(make_tiled_copy( + Copy_Atom, DType>{}, + Layout>{}, // Thr layout 1x16 n-major + Layout>{})); // Val layout 8x1 + + static constexpr int SubgroupSize = 16; + static constexpr int smem_size = 0; + + FAKernel() {} +}; +template +struct COPY_Trait { + using StrideR = cute::tuple>; + using StrideC = cute::tuple, long>; + // using VEC_COPY = Copy_Atom, DType>; + // using LOAD_2D_16x16_N_R = std::conditional_t< + // is_even_n, + // Copy_Atom, DType>, + // VEC_COPY>; + // using LOAD_2D_16x16_T_R = std::conditional_t< + // is_even_n, + // Copy_Atom, DType>, + // VEC_COPY>; + // using SAVE_2D_8x16_N_R = std::conditional_t< + // is_even_n, + // Copy_Atom, DType>, + // VEC_COPY>; + // for load Q and Kt in S=QKt + using TiledLoadQ = decltype(make_tiled_copy( + Copy_Atom, DType>{}, + Layout>{}, // Thr layout 1x16 k-major + Layout>{})); // Val layout 16x1 + using TiledLoadKt = decltype(make_tiled_copy( + Copy_Atom, DType>{}, + Layout>{}, // Thr layout 1x16 n-major + Layout>{})); // Val layout 16x1 + + // for load dO and Vt in dP=dO*Vt + using TiledLoaddO = decltype(make_tiled_copy( + Copy_Atom, DType>{}, + Layout>{}, // Thr layout 1x16 k-major + Layout>{})); // Val layout 16x1 + + using 
TiledLoadV = decltype(make_tiled_copy( + Copy_Atom, DType>{}, + Layout>{}, // Thr layout 1x16 n-major + Layout>{})); // Val layout 16x1 + + // for load Pt and dO in dV=Pt*dO + using TiledLoadPt = decltype(make_tiled_copy( + Copy_Atom, DType>{}, + Layout>{}, // Thr layout 1x16 m-major + Layout>{})); // // Val layout 8x1 + using TiledLoaddOt = decltype(make_tiled_copy( + Copy_Atom, DType>{}, // should be V here + Layout>{}, // Thr layout 1x16 n-major + Layout>{})); // val layout 16x1 + + // for load dP, K and dQ in dQ=dP*K + using TiledLoaddP = decltype(make_tiled_copy( + Copy_Atom, DType>{}, + Layout>{}, // Thr layout 1x16 k-major + Layout>{})); // val layout 16x1 + using TiledLoadK = decltype(make_tiled_copy( + Copy_Atom, DType>{}, + Layout>{}, // Thr layout 1x16 n-major + Layout>{})); // val layout 16x1 + + using TiledLoaddQ = decltype(make_tiled_copy( + Copy_Atom, VType>{}, + Layout>{}, // Thr layout 1x16 n-major + Layout>{})); // val layout 8x1 + + // for load dPt, Q in dK=dPt*Q + using TiledLoaddPt = decltype(make_tiled_copy( + Copy_Atom, DType>{}, + Layout>{}, // Thr layout 1x16 k-major + Layout>{})); // Val layout 16x1 + using TiledLoadQt = decltype(make_tiled_copy( + Copy_Atom, DType>{}, + Layout>{}, // Thr layout 1x16 n-major + Layout>{})); // Val layout 16x1 + + // for save S in S=QKt and P + using TiledSaveS = decltype(make_tiled_copy( + Copy_Atom, DType>{}, + Layout>{}, // Thr layout 1x16 n-major + Layout>{})); // Val layout 8x1 + // for save dP in dP=dO*Vt + using TiledSavedP = decltype(make_tiled_copy( + Copy_Atom, DType>{}, + Layout>{}, // Thr layout 1x16 n-major + Layout>{})); // Val layout 8x1 + // for save dV in dV=Pt*dO + using TiledSavedV = decltype(make_tiled_copy( + Copy_Atom, DType>{}, + Layout>{}, // Thr layout 1x16 n-major + Layout>{})); // Val layout 8x1 + // for save dQ in dQ=dP*K + using TiledSavedQ = decltype(make_tiled_copy( + Copy_Atom, VType>{}, + Layout>{}, // Thr layout 1x16 n-major + Layout>{})); // val layout 8x1 + // for save dK=dPt*Q + using TiledSavedK = decltype(make_tiled_copy( + Copy_Atom, DType>{}, + Layout>{}, // Thr layout 1x16 n-major + Layout>{})); // Val layout 8x1 +}; + +using index_t = uint64_t; + +template +struct Param { + Param(const T *dO, + const T *o, + const T *q, + const T *k, + const T *v, + const float *lse, + float *odo, + float *dqaccum, + T *dq, + T *dk, + T *dv, + T *s, + T *dp, + T *pb, + const float softmax_scale) + : do_ptr(dO), + o_ptr(o), + q_ptr(q), + k_ptr(k), + v_ptr(v), + lse_ptr(lse), + odo_ptr(odo), + dqaccum_ptr(dqaccum), + dq_ptr(dq), + dk_ptr(dk), + dv_ptr(dv), + s_ptr(s), + dp_ptr(dp), + pb_ptr(pb), + scale_softmax(softmax_scale), + scale_softmax_log2(softmax_scale * M_LOG2E), + is_bhsd(true) {} + // read only + const T *do_ptr; + const T *o_ptr; + const T *q_ptr; + const T *k_ptr; + const T *v_ptr; + const float *lse_ptr; + const float scale_softmax; + const float scale_softmax_log2; + // write + float *odo_ptr; + float *dqaccum_ptr; + T *dq_ptr; + T *dk_ptr; + T *dv_ptr; + T *s_ptr; + T *dp_ptr; + T *pb_ptr; + + // const dimension + int batch; + int num_head_q; + int num_head_kv; + int seq_len_q; + int seq_len_q_pad; + int seq_len_kv; + int seq_len_kv_pad; + int head_dim; + int n_block; + int tail_n; + int m_block; + int tail_m; + int q_r_stride; + int q_h_stride; + int q_b_stride; + + int k_r_stride; + int k_h_stride; + int k_b_stride; + + int v_r_stride; + int v_h_stride; + int v_b_stride; + + int o_r_stride; + int o_h_stride; + int o_b_stride; + + int s_r_stride; + int s_s_stride; + int s_b_stride; + + int 
dq_r_stride; + int dq_h_stride; + int dq_b_stride; + /* + * input output layout + * true batch, numhead, seqlen, headsize + * false batch, seqlen, numhead, headsize + */ + bool is_bhsd; +}; + +template +struct Boffset { + Boffset(Param ¶m_) : param(param_) {} + index_t q_offset(const index_t b_id, const index_t h_id, const index_t s_id) { + return b_id * param.q_b_stride + h_id * param.q_h_stride + s_id * param.q_r_stride; + } + index_t k_offset(const index_t b_id, const index_t h_id, const index_t s_id) { + return b_id * param.k_b_stride + h_id * param.k_h_stride + s_id * param.k_r_stride; + } + index_t v_offset(const index_t b_id, const index_t h_id, const index_t s_id) { + return b_id * param.v_b_stride + h_id * param.v_h_stride + s_id * param.v_r_stride; + } + index_t ps_offset(const index_t b_id, const index_t h_id, + const index_t sq_id, const index_t sk_id) { + return b_id * param.s_b_stride + + h_id * param.s_s_stride + + sq_id * param.s_r_stride + sk_id; + } + index_t lse_offset(const index_t b_id, const index_t h_id, const index_t s_id) { + return b_id * param.seq_len_q * param.num_head_q + h_id * param.seq_len_q + s_id; + } + + index_t o_offset(const index_t b_id, const index_t h_id, const index_t s_id) { + return b_id * param.o_b_stride + h_id * param.o_h_stride + s_id * param.o_r_stride; + } + + index_t dq_offset(const index_t b_id, const index_t h_id, const index_t s_id) { + return b_id * param.dq_b_stride + h_id * param.dq_h_stride + s_id * param.dq_r_stride; + } + Param ¶m; +}; + +// for debug +template +void setup_bhsd_stride(Param ¶m) { + param.q_r_stride = param.head_dim; + param.q_h_stride = param.seq_len_q * param.head_dim; + param.q_b_stride = param.num_head_q * param.seq_len_q * param.head_dim; + + // param.dq_r_stride = param.head_dim; + // param.dq_h_stride = param.seq_len_q * param.head_dim; + // param.dq_b_stride = param.num_head_q * param.seq_len_q * param.head_dim; + + param.k_r_stride = param.head_dim; + param.k_h_stride = param.seq_len_kv * param.head_dim; + param.k_b_stride = param.num_head_kv * param.seq_len_kv * param.head_dim; + + // param.dk_r_stride = param.head_dim; + // param.dk_h_stride = param.seq_len_kv * param.head_dim; + // param.dk_b_stride = param.num_head_kv * param.seq_len_kv * param.head_dim; + + param.v_r_stride = param.head_dim; + param.v_h_stride = param.seq_len_kv * param.head_dim; + param.v_b_stride = param.num_head_kv * param.seq_len_kv * param.head_dim; + + // param.dv_r_stride = param.head_dim; + // param.dv_h_stride = param.seq_len_kv * param.head_dim; + // param.dv_b_stride = param.num_head_kv * param.seq_len_kv * param.head_dim; + + param.o_r_stride = param.head_dim; + param.o_h_stride = param.seq_len_q * param.head_dim; + param.o_b_stride = param.num_head_q * param.seq_len_q * param.head_dim; + + // param.do_r_stride = param.head_dim; + // param.do_h_stride = param.seq_len_q * param.head_dim; + // param.do_b_stride = param.num_head_q * param.seq_len_q * param.head_dim; + param.s_r_stride = param.seq_len_kv_pad; + param.s_s_stride = param.seq_len_q_pad * param.seq_len_kv_pad; + param.s_b_stride = param.num_head_q * param.seq_len_q_pad * param.seq_len_kv_pad; + + param.dq_r_stride = param.head_dim; + param.dq_h_stride = param.seq_len_q_pad * param.head_dim; + param.dq_b_stride = param.num_head_q * param.seq_len_q_pad * param.head_dim; +} + +template +void setup_bshd_stride(Param ¶m) { + param.q_r_stride = param.num_head_q * param.head_dim; + param.q_h_stride = param.head_dim; + param.q_b_stride = param.num_head_q * param.seq_len_q 
* param.head_dim; + + // param.dq_r_stride = param.head_dim; + // param.dq_h_stride = param.seq_len_q * param.head_dim; + // param.dq_b_stride = param.num_head_q * param.seq_len_q * param.head_dim; + + param.k_r_stride = param.num_head_kv * param.head_dim; + param.k_h_stride = param.head_dim; + param.k_b_stride = param.num_head_kv * param.seq_len_kv * param.head_dim; + + // param.dk_r_stride = param.head_dim; + // param.dk_h_stride = param.seq_len_kv * param.head_dim; + // param.dk_b_stride = param.num_head_kv * param.seq_len_kv * param.head_dim; + + param.v_r_stride = param.num_head_kv * param.head_dim; + param.v_h_stride = param.head_dim; + param.v_b_stride = param.num_head_kv * param.seq_len_kv * param.head_dim; + + // param.dv_r_stride = param.head_dim; + // param.dv_h_stride = param.seq_len_kv * param.head_dim; + // param.dv_b_stride = param.num_head_kv * param.seq_len_kv * param.head_dim; + + param.o_r_stride = param.num_head_q * param.head_dim; + param.o_h_stride = param.head_dim; + param.o_b_stride = param.num_head_q * param.seq_len_q * param.head_dim; + + // param.do_r_stride = param.head_dim; + // param.do_h_stride = param.seq_len_q * param.head_dim; + // param.do_b_stride = param.num_head_q * param.seq_len_q * param.head_dim; + param.s_r_stride = param.seq_len_kv_pad; + param.s_s_stride = param.seq_len_q_pad * param.seq_len_kv_pad; + param.s_b_stride = param.num_head_q * param.seq_len_q_pad * param.seq_len_kv_pad; + + param.dq_r_stride = param.num_head_q * param.head_dim; + param.dq_h_stride = param.head_dim; + param.dq_b_stride = param.num_head_q * param.seq_len_q_pad * param.head_dim; +} diff --git a/examples/sycl/sdpa_bwd/sdpa_backward.cpp b/examples/sycl/sdpa_bwd/sdpa_backward.cpp new file mode 100644 index 0000000000..d4be7cebf6 --- /dev/null +++ b/examples/sycl/sdpa_bwd/sdpa_backward.cpp @@ -0,0 +1,1151 @@ +#include +#include +#include +#include + +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/print_error.hpp" +#include "cutlass/util/sycl_event_manager.hpp" +#include "cutlass/util/GPU_Clock.hpp" +#include "cnpy.h" +#include "sdpa_util.hpp" +#include "params.hpp" + + +void read_args(int argc, char**argv, int n, int64_t *p) { + if (argc >= n + 1) + sscanf(argv[n], "%ld", p); +} + +void +debug_info() { + print("block idx (%d,%d,%d) dim (%d,%d,%d) thread idx (%d,%d,%d) dim (%d,%d,%d)\n", + BlockIdxX(), BlockIdxY(), BlockIdxZ(), + GridDimX(), GridDimY(), GridDimZ(), + ThreadIdxX(), ThreadIdxY(), ThreadIdxZ(), + BlockDimX(), BlockDimY(), BlockDimZ()); +} + +template +void print_t(T t) { + print(t); + for (int i = 0; i < size(t); ++i) { + if (i % 8 == 0) + print("\n(%03d): ", i / 8); + print("%10.7f ", (float)t(i)); + } + print("\n"); +} + +template +void print_t_2d(T t) { + static_assert(rank(t) == 2, "Only support 2D Tensor"); + print(t); + for (int i = 0; i < size < 0>(t); ++i) { + print("\n(%03d): ", i); + for (int j = 0; j < size<1>(t); ++j) { + print("%10.7f ", (float)t(i,j)); + } + } + print("\n"); +} + +template +void print_d(T t) { + print(t); + for (int i = 0; i < size(t); ++i) { + if (i % 8 == 0) + print("\n(%03d): ", i / 8); + print("%10u ", t(i)); + } + print("\n"); +} + +using ProblemShapeRegular = cute::tuple; // batch, num_head_q,num_head_kv,seq_len_qo,seq_len_kv,head_size_qk,head_size_vo + +template +struct OPS_tobf16{ + template + auto operator()(Tensor &src){ + cutlass::NumericConverter< + T, float, cutlass::FloatRoundStyle::round_toward_zero> converter; + auto dst = 
make_tensor_like(src); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src); ++i) { + dst(i) = converter(src(i)); + } + return dst; + } +}; + +constexpr int tid = 0; +constexpr int bid = 0; + + +template +CUTLASS_DEVICE void gemm_ker(Tensor0 &tCrC, Tensor1 &tCrA, Tensor2 &tCrB, + Tensor3 &tAgA, Tensor4 &tArA, + Tensor5 &tBgB, Tensor6 &tBrB, TiledMma &tiled_mma, + TiledCopyA ©_a, TiledCopyB ©_b, + ThrCopyA &thr_copy_a, ThrCopyB &thr_copy_b) { + constexpr int barrier_scope = 2; + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<3>(tAgA); ++k) { + barrier_arrive(barrier_scope); + cute::copy(copy_a, tAgA(_, _, _, k), tArA); + cute::copy(copy_b, tBgB(_, _, _, k), tBrB); + cute::gemm(tiled_mma, tCrA, tCrB, tCrC); + barrier_wait(barrier_scope); + } +} + +template +CUTLASS_DEVICE void load_1colvec(Tensor0 ®, Tensor1 &mT, Tensor2 &coord_row) { + CUTLASS_PRAGMA_UNROLL + for (int mi = 0; mi < size(reg); ++mi) { + reg(mi) = mT(get<0>(coord_row(mi))); + } +} +template +CUTLASS_DEVICE auto convert_layout_acc_layout(Layout acc_layout) { + static_assert(decltype(size<0>(acc_layout))::value == 8); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = logical_divide(acc_layout, Shape<_1>{}); // ((2, 2), MMA_M, MMA_N) + return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); +} + +template +CUTLASS_DEVICE void scale_apply_exp2(Tensor &tensor, Tensor &max, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); + CUTLASS_PRAGMA_UNROLL + for (int mi = 0; mi < size<0>(tensor); ++mi) { + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * M_LOG2E; + CUTLASS_PRAGMA_UNROLL + for (int ni = 0; ni < size<1>(tensor); ++ni) { + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + } + } +} + +template +CUTLASS_DEVICE void softmax_backward(Tensor0 &P, Tensor1 &dP_sum, Tensor2 &dP, const float scale) { + CUTLASS_PRAGMA_UNROLL + for (int mi = 0; mi < size<0>(dP); ++mi) { + CUTLASS_PRAGMA_UNROLL + for (int mj = 0; mj < size<1>(dP); ++mj) { + dP(mi, mj) = P(mi, mj) * (dP(mi, mj) - dP_sum(mi)) * scale; + } + } +} + +template +CUTLASS_DEVICE auto convert_type(Tensor const &tensor) { + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + auto frag = convert_op(*reinterpret_cast *>(tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); +} + +template +CUTLASS_DEVICE auto convert_type(CVT &cvt, T0 &src, T1 &dst) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src); ++i) { + dst(i) = cvt(src(i)); + } +} + +template +void +dq_dk_dv_1colblock(Trait &trait, Param ¶m, + const int bidb, const int bidh, const int n_block, + const int tail_n = 0) { + using T = typename Trait::DType; + using V = typename Trait::VType; + constexpr int kHeadDim = Trait::kHeadDim; + constexpr int kBlockM = Trait::kBlockM; + constexpr int kBlockN = Trait::kBlockN; + constexpr int kBlockK = Trait::kBlockK; + auto sg = syclcompat::get_nd_item<1>().get_sub_group(); + auto group = syclcompat::get_nd_item<1>().get_group(); + auto first_thread_in_sg_idx = sg.get_group_linear_id() * trait.SubgroupSize; + auto bofst = Boffset(param); + + const index_t q_offset = bofst.q_offset(bidb, bidh, 0); + const index_t k_offset = bofst.k_offset(bidb, bidh, n_block * kBlockN); + const index_t v_offset = bofst.v_offset(bidb, bidh, 
n_block * kBlockN); + const index_t o_offset = bofst.o_offset(bidb, bidh, 0); + const index_t do_offset = bofst.o_offset(bidb, bidh, 0); + const index_t dv_offset = bofst.v_offset(bidb, bidh, n_block * kBlockN); + const index_t dq_offset = bofst.dq_offset(bidb, bidh, 0); + const index_t lse_offset = bofst.lse_offset(bidb, bidh, 0); + const index_t dpsum_offset = bofst.lse_offset(bidb, bidh, 0); + // buff offset + const index_t pb_offset = bidb * param.num_head_q * param.seq_len_kv_pad * kBlockM + + bidh * param.seq_len_kv_pad * kBlockM + n_block * kBlockN * kBlockM; + + const index_t s_offset = bofst.ps_offset(bidb, bidh, 0, n_block * kBlockN); + const index_t dp_offset = bofst.ps_offset(bidb, bidh, 0, n_block * kBlockN); + + const auto block_n_dim = tail_n == 0 ? Int{} : tail_n; + using Shape1 = Shape< + std::conditional_t, int>, + Int , Int<1>>; + using Shape2 = Shape< + Int , + std::conditional_t, int>, + Int<1>>; + Shape shapeQ = make_shape(kBlockM, Int{}, _1{}); + Shape1 shapeKt, shapeV; + Shape2 shapeK; + if constexpr(Is_even_MN) { + shapeKt = make_shape(Int{}, Int{}, _1{}); + shapeV = make_shape(Int{}, Int{}, _1{}); + shapeK = make_shape(Int{}, Int{}, _1{}); + } else { + shapeKt = make_shape(tail_n, Int{}, _1{}); + shapeV = make_shape(tail_n, Int{}, _1{}); + shapeK = make_shape(Int{}, tail_n, _1{}); + } + Shape shapeO = make_shape(kBlockM, Int{}, _1{}); + Shape shapeOt = make_shape(Int{}, kBlockM, _1{}); + + + Shape shapeSP = make_shape(kBlockM, block_n_dim, _1{}); + Shape shapedP = make_shape(kBlockM, block_n_dim, _1{}); + + Shape shapePt = make_shape(block_n_dim, kBlockM, _1{}); + Shape shapeQt = make_shape(Int{}, kBlockM, _1{}); + + Tensor mQ = make_tensor(make_gmem_ptr(param.q_ptr + q_offset), + make_layout( + shapeQ, + make_stride(param.q_r_stride, _1{}, _1{}))); + Tensor mKt = make_tensor(make_gmem_ptr(param.k_ptr + k_offset), + make_layout( + shapeKt, + make_stride(param.k_r_stride, _1{}, _1{}))); + Tensor mV = make_tensor(make_gmem_ptr(param.v_ptr + v_offset), + make_layout( + shapeV, + make_stride(param.v_r_stride, _1{}, _1{}))); + Tensor mdO = make_tensor(make_gmem_ptr(param.do_ptr + do_offset), + make_layout( + shapeO, + make_stride(param.o_r_stride, _1{}, _1{}))); + // intermediate buffer + Tensor mP = make_tensor(make_gmem_ptr(param.pb_ptr + pb_offset), + make_layout( + shapeSP, + make_stride(block_n_dim, _1{}, _1{}))); + Tensor mPt = make_tensor(make_gmem_ptr(param.pb_ptr + pb_offset), + make_layout( + shapePt, + make_stride(_1{}, block_n_dim, _1{}))); + Tensor mdOt = make_tensor(make_gmem_ptr(param.do_ptr + do_offset), + make_layout( + shapeOt, + make_stride(_1{}, param.o_r_stride, _1{}))); + Tensor mK = make_tensor(make_gmem_ptr(param.k_ptr + k_offset), + make_layout( + shapeK, + make_stride(_1{}, param.k_r_stride, _1{}))); + Tensor mdPt = make_tensor(make_gmem_ptr(param.pb_ptr + pb_offset), + make_layout( + shapePt, + make_stride(_1{}, block_n_dim, _1{}))); + Tensor mQt = make_tensor(make_gmem_ptr(param.q_ptr + q_offset), + make_layout( + shapeQt, + make_stride(_1{}, param.q_r_stride, _1{}))); + + Tensor mLSE = make_tensor(make_gmem_ptr(param.lse_ptr + lse_offset), + make_layout( + Shape>{}, + Stride<_1>{})); + Tensor mdPsum = make_tensor(make_gmem_ptr(param.odo_ptr + dpsum_offset), + make_layout( + Shape>{}, + Stride<_1>{})); + + Tensor mdV = make_tensor(make_gmem_ptr(param.dv_ptr + dv_offset), + make_layout( + shapeV, + make_stride(param.v_r_stride, _1{}, _1{}))); + Tensor mdP = make_tensor(make_gmem_ptr(param.pb_ptr + pb_offset), + make_layout( + shapedP, + 
make_stride(block_n_dim, _1{}, _1{}))); + Tensor mdQaccum = make_tensor(make_gmem_ptr(param.dqaccum_ptr + dq_offset), + make_layout( + shapeQ, + make_stride(param.dq_r_stride, _1{}, _1{}))); + Tensor mdK = make_tensor(make_gmem_ptr(param.dk_ptr+k_offset), + make_layout( + shapeKt, + make_stride(param.k_r_stride, _1{}, _1{}))); + + Tensor mS = make_tensor(make_gmem_ptr(param.s_ptr + s_offset), make_layout( + shapeSP, + make_stride(param.s_r_stride, _1{}, _1{}))); + Tensor mdPd = make_tensor(make_gmem_ptr(param.dp_ptr + s_offset), make_layout( + shapeSP, + make_stride(param.s_r_stride, _1{}, _1{}))); + + Shape tile_sdp = typename Trait::TileShapeSdP{}; + Shape tile_dkv = typename Trait::TileShapedKV{}; + Shape tile_dq = typename Trait::TileShapedQ{}; + + auto tileloadQ = typename Trait::TiledLoadQ{mQ}; + auto tileloadKt = typename Trait::TiledLoadKt{mKt}; + auto tileloaddO = typename Trait::TiledLoaddO{mdO}; + auto tileloadV = typename Trait::TiledLoadV{mV}; + auto tileloadPt = typename Trait::TiledLoadPt{mPt}; + auto tileloaddOt = typename Trait::TiledLoaddOt{mdOt}; // load dO as operand B for dV=Pt*dO + auto tileloaddP = typename Trait::TiledLoaddP{mdP}; + auto tileloadK = typename Trait::TiledLoadK{mK}; + auto tileloaddQ = typename Trait::TiledLoaddQ{mdQaccum}; + auto tileloaddPt = typename Trait::TiledLoaddPt{mdPt}; + auto tileloadQt = typename Trait::TiledLoadQt{mQt}; + + auto tilesaveP = typename Trait::TiledSaveS{mP}; // to internal buffer + auto tilesavedV = typename Trait::TiledSavedV{mdV}; + auto tilesavedP = typename Trait::TiledSavedP{mdP}; + auto tilesavedQ = typename Trait::TiledSavedQ{mdQaccum}; + auto tilesavedK = typename Trait::TiledSavedK{mdK}; + + + Tensor mQ_coord = cute::get_xe_tensor(shapeQ); + Tensor mKt_coord = cute::get_xe_tensor(shapeKt); + Tensor mV_coord = cute::get_xe_tensor(shapeV); + Tensor mdO_coord = cute::get_xe_tensor(shapeO); + Tensor mdOt_coord = cute::get_xe_tensor(shapeOt); + Tensor mdV_coord = cute::get_xe_tensor(shapeV); + Tensor mK_coord = cute::get_xe_tensor(shapeK); + Tensor mQt_coord = cute::get_xe_tensor(shapeQt); + + Tensor mS_coord = cute::get_xe_tensor(shapeSP); + Tensor mPt_coord = cute::get_xe_tensor(shapePt); + Tensor mdP_coord = cute::get_xe_tensor(shapedP); + + typename Trait::TiledMmaSdP tiled_mma_sdp; + typename Trait::TiledMmadKV tiled_mma_dkv; + typename Trait::TiledMmadQ tiled_mma_dq; + + auto thr_mma_sdp = tiled_mma_sdp.get_slice(first_thread_in_sg_idx); + auto thr_mma_dkv = tiled_mma_dkv.get_slice(first_thread_in_sg_idx); + auto thr_mma_dq = tiled_mma_dq.get_slice(first_thread_in_sg_idx); + + Tensor gQ = local_tile(mQ_coord, select<0, 2>(tile_sdp), make_coord(0, _, 0)); + Tensor gKt = local_tile(mKt_coord, select<1, 2>(tile_sdp), make_coord(0, _, 0)); + Tensor gdO = local_tile(mdO_coord, select<0, 2>(tile_sdp), make_coord(0, _, 0)); + Tensor gV = local_tile(mV_coord, select<1, 2>(tile_sdp), make_coord(0, _, 0)); + Tensor gPt = local_tile(mPt_coord, select<0, 2>(tile_dkv), make_coord(0, _, 0)); // load Pt + Tensor gdOt = local_tile(mdOt_coord, select<1, 2>(tile_dkv), make_coord(0, _, 0)); + Tensor gdPa = local_tile(mdP_coord, select<0, 2>(tile_dq), make_coord(0, _, 0)); // operand A dQ + Tensor gK = local_tile(mK_coord, select<1, 2>(tile_dq), make_coord(0, _, 0)); // operand B dQ + Tensor gdPt = local_tile(mPt_coord, select<0, 2>(tile_dkv), make_coord(0, _, 0)); // load dpt + Tensor gQt = local_tile(mQt_coord, select<1, 2>(tile_dkv), make_coord(0, _, 0)); // load Q as operand B + + Tensor gP = local_tile(mS_coord, select<0, 
1>(tile_sdp), make_coord(0, 0, 0)); // dump P + Tensor gdP = local_tile(mdP_coord, select<0, 1>(tile_sdp), make_coord(0, 0, 0)); // dump dP + Tensor gdV = local_tile(mdV_coord, select<0, 1>(tile_dkv), make_coord(0, 0, 0)); // dump dV + Tensor gdQ = local_tile(mQ_coord, select<0, 1>(tile_dq), make_coord(0, 0, 0)); // dump dQ + Tensor gdK = local_tile(mK_coord, select<0, 1>(tile_dkv), make_coord(0, 0, 0)); // dump dK + + + Tensor tSgQ = thr_mma_sdp.partition_A(gQ); + Tensor tSgKt = thr_mma_sdp.partition_B(gKt); + Tensor tdPgdO = thr_mma_sdp.partition_A(gdO); + Tensor tdPgV = thr_mma_sdp.partition_B(gV); + Tensor tdVgPt = thr_mma_dkv.partition_A(gPt); + Tensor tdVgdOt = thr_mma_dkv.partition_B(gdOt); + Tensor tdQgdP = thr_mma_dq.partition_A(gdPa); + Tensor tdQgK = thr_mma_dq.partition_B(gK); + Tensor tdKgdPt = thr_mma_dkv.partition_A(gdPt); + Tensor tdKgQt = thr_mma_dkv.partition_B(gQt); + + Tensor tPgP = thr_mma_sdp.partition_C(gP); // save P to internal buffer + Tensor tdPgdP = thr_mma_sdp.partition_C(gdP); // save dP to internal buffer + Tensor tdVgdV = thr_mma_dkv.partition_C(gdV); // save to dv + Tensor tdQgdQ = thr_mma_dq.partition_C(gdQ); // save to dq + Tensor tdKgdK = thr_mma_dkv.partition_C(gdK); // save to dk + + + Tensor tSrQ = make_tensor(make_fragment_layout(tileloadQ, tSgQ(_,_,_,0).shape())); + Tensor tSrKt = make_tensor(make_fragment_layout(tileloadKt, tSgKt(_,_,_,0).shape())); + Tensor tdPrdO = make_tensor(make_fragment_layout(tileloaddO, tdPgdO(_,_,_,0).shape())); + Tensor tdPrV = make_tensor(make_fragment_layout(tileloadV, tdPgV(_,_,_,0).shape())); + Tensor tdVrPt = make_tensor(make_fragment_layout(tileloadPt, tdVgPt(_,_,_,0).shape())); + Tensor tdVrdOt = make_tensor(make_fragment_layout(tileloaddOt, tdVgdOt(_,_,_,0).shape())); + Tensor tdQrdP = make_tensor(make_fragment_layout(tileloaddP, tdQgdP(_,_,_,0).shape())); + Tensor tdQrK = make_tensor(make_fragment_layout(tileloadK, tdQgK(_,_,_,0).shape())); + Tensor tdKrdPt = make_tensor(make_fragment_layout(tileloaddPt, tdKgdPt(_,_,_,0).shape())); + Tensor tdKrQt = make_tensor(make_fragment_layout(tileloadQt, tdKgQt(_,_,_,0).shape())); + + ThrCopy thr_copy_q = tileloadQ.get_slice(syclcompat::local_id::x()); + ThrCopy thr_copy_kt = tileloadKt.get_slice(syclcompat::local_id::x()); + ThrCopy thr_copy_do = tileloaddO.get_slice(syclcompat::local_id::x()); + ThrCopy thr_copy_v = tileloadV.get_slice(syclcompat::local_id::x()); + ThrCopy thr_copy_pt = tileloadPt.get_slice(syclcompat::local_id::x()); + ThrCopy thr_copy_dot = tileloaddOt.get_slice(syclcompat::local_id::x()); + ThrCopy thr_copy_dp = tileloaddP.get_slice(syclcompat::local_id::x()); + ThrCopy thr_copy_k = tileloadK.get_slice(syclcompat::local_id::x()); + ThrCopy thr_copy_dpt = tileloaddPt.get_slice(syclcompat::local_id::x()); + ThrCopy thr_copy_qt = tileloadQt.get_slice(syclcompat::local_id::x()); + + // Retile registers for copies + Tensor tQrQ = thr_copy_q.retile_D(tSrQ); + Tensor tKtrKt = thr_copy_kt.retile_D(tSrKt); + Tensor tdOrdO = thr_copy_do.retile_D(tdPrdO); + Tensor tVrV = thr_copy_v.retile_D(tdPrV); + Tensor tPtrPt = thr_copy_pt.retile_D(tdVrPt); + Tensor tdOtrdOt = thr_copy_dot.retile_D(tdVrdOt); + Tensor tdPrdPa = thr_copy_dp.retile_D(tdQrdP); + Tensor tKrK = thr_copy_k.retile_D(tdQrK); + Tensor tdPtrdPt = thr_copy_dpt.retile_D(tdKrdPt); + Tensor tQtrQt = thr_copy_qt.retile_D(tdKrQt); + + // Retile global counting tensors for copies + Tensor tQgQ = thr_copy_q.retile_S(tSgQ); + Tensor tKtgKt = thr_copy_kt.retile_S(tSgKt); + Tensor tdOgdO = 
thr_copy_do.retile_S(tdPgdO); + Tensor tVgV = thr_copy_v.retile_S(tdPgV); + Tensor tPtgPt = thr_copy_pt.retile_S(tdVgPt); + Tensor tdOtgdOt = thr_copy_dot.retile_S(tdVgdOt); + Tensor tdPgdPa = thr_copy_dp.retile_S(tdQgdP); + Tensor tKgK = thr_copy_k.retile_S(tdQgK); + Tensor tdPtgdPt = thr_copy_dpt.retile_S(tdKgdPt); + Tensor tQtgQt = thr_copy_qt.retile_S(tdKgQt); + + Tensor tSrS = partition_fragment_C(tiled_mma_sdp, Shape, Int>{}); + Tensor tdPrdP = partition_fragment_C(tiled_mma_sdp, Shape, Int>{}); + Tensor tdVrdV = partition_fragment_C(tiled_mma_dkv, Shape, Int>{}); + Tensor tdQrdQ = partition_fragment_C(tiled_mma_dq, Shape, Int>{}); + Tensor tdKrdK = partition_fragment_C(tiled_mma_dkv, Shape, Int>{}); + + + // for lse read + Tensor caccS = make_identity_tensor(Shape, Int>{}); // same buffer as accS + Tensor taccScS = thr_mma_sdp.partition_C(caccS); + static_assert(decltype(size<0>(taccScS))::value == 8); + Tensor taccScS_row = logical_divide(taccScS, Shape<_1>{})(make_coord(0, _), _, 0); + Tensor lse = make_tensor(Shape>{}); + static_assert(size<0>(tSrS) * size<1>(tSrS) == size<0>(lse) && "row of acc and lse not match"); + // misc + + const int max_m_block = ceil_div(param.seq_len_q, kBlockM); + const int tail_m = param.seq_len_q % kBlockM; + + constexpr int k_tile = ceil_div(kHeadDim, kBlockK); + + cutlass::NumericConverter converter; + // clear accumulator + clear(tdVrdV); + clear(tdKrdK); + for (int m_block = 0; m_block < max_m_block; ++m_block) { + if (m_block == max_m_block - 1 and tail_m != 0) { + mQ = make_tensor(make_gmem_ptr(mQ.data()), + make_layout( + make_shape(tail_m, Int{}, _1{}), + make_stride(param.q_r_stride, _1{}, _1{}))); + mdO = make_tensor(make_gmem_ptr(mdO.data()), + make_layout( + make_shape(tail_m, Int{}, _1{}), + make_stride(param.o_r_stride, _1{}, _1{}))); + mdOt = make_tensor(make_gmem_ptr(mdOt.data()), + make_layout( + make_shape(Int{}, tail_m, _1{}), + make_stride(_1{}, param.o_r_stride, _1{}))); + mdQaccum = make_tensor(make_gmem_ptr(mdQaccum.data()), + make_layout( + make_shape(tail_m, Int{}, _1{}), + make_stride(param.dq_r_stride, _1{}, _1{}))); + mQt = make_tensor(make_gmem_ptr(mQt.data()), + make_layout( + make_shape(Int{}, tail_m, _1{}), + make_stride(_1{}, param.q_r_stride, _1{}))); + // Tensor mK = make_tensor(make_gmem_ptr(param.k_ptr + k_offset), + // make_layout( + // shapeK, + // make_stride(param.k_r_stride, _1{}, _1{}))); + + tileloadQ = typename Trait::TiledLoadQ{mQ}; + // auto tileloadK = typename Trait::TiledLoadK{mK}; + tileloaddO = typename Trait::TiledLoaddO{mdO}; + tileloaddOt = typename Trait::TiledLoaddOt{mdOt}; + tileloaddQ = typename Trait::TiledLoaddQ{mdQaccum}; + tileloadQt = typename Trait::TiledLoadQt{mQt}; + tilesavedQ = typename Trait::TiledSavedQ{mdQaccum}; + } + clear(tSrS); + clear(tdPrdP); + clear(tdQrdQ); + // S=QKt + gemm_ker(tSrS, tSrQ, tSrKt, tQgQ, tQrQ, tKtgKt, tKtrKt, tiled_mma_sdp, tileloadQ, tileloadKt, thr_copy_q, thr_copy_kt); + load_1colvec(lse, mLSE, taccScS_row); + Tensor dP_sum = make_fragment_like(lse); + load_1colvec(dP_sum, mdPsum, taccScS_row); + Tensor scores = make_tensor(tSrS.data(), convert_layout_acc_layout(tSrS.layout())); + + // P=softmax(S,lse) + scale_apply_exp2(scores, lse, param.scale_softmax_log2); + auto tSrSl = make_tensor_like(tSrS); + convert_type(converter, tSrS, tSrSl); + copy(tilesaveP, tSrSl, tPgP); // save P to internal buffers + // dP=dO*Vt + gemm_ker(tdPrdP, tdPrdO, tdPrV, tdOgdO, tdOrdO, tVgV, tVrV, tiled_mma_sdp, tileloaddO, tileloadV, thr_copy_do, thr_copy_v); + Tensor dS = 
make_tensor(tdPrdP.data(), scores.layout()); + // dS=P(dP-sum_row(P))*scale + softmax_backward(scores, dP_sum, dS, param.scale_softmax); + auto tdPrdPl = make_tensor_like(tdPrdP); + convert_type(converter, tdPrdP, tdPrdPl); + + // dV=Pt*dO + gemm_ker(tdVrdV, tdVrPt, tdVrdOt, tPtgPt, tPtrPt, tdOtgdOt, tdOtrdOt, tiled_mma_dkv, tileloadPt, tileloaddOt, thr_copy_pt, thr_copy_dot); + sycl::group_barrier(group); + + copy(tilesavedP, tdPrdPl, tdPgdP); // save dP to buffer after P used by dV + + if (n_block > 0) + copy(tileloaddQ, tdQgdQ, tdQrdQ); + // dQ=dP*K + gemm_ker(tdQrdQ, tdQrdP, tdQrK, tdPgdPa, tdPrdPa, tKgK, tKrK, tiled_mma_dq, tileloaddP, tileloadK, thr_copy_dp, thr_copy_k); + copy(tilesavedQ, tdQrdQ, tdQgdQ); + // dK=dPt*Q + gemm_ker(tdKrdK, tdKrdPt, tdKrQt, tdPtgdPt, tdPtrdPt, tQtgQt, tQtrQt, tiled_mma_dkv, tileloaddPt, tileloadQt, thr_copy_dpt, thr_copy_qt); + // update ptr/atom copy + mQ.data() = mQ.data() + int(kBlockM * param.q_r_stride); + mdO.data() = mdO.data() + int(kBlockM * param.o_r_stride); + mdOt.data() = mdOt.data() + int(kBlockM * param.o_r_stride); + mdQaccum.data() = mdQaccum.data() + int(kBlockM * param.dq_r_stride); + mQt.data() = mQt.data() + int(kBlockM * param.q_r_stride); + mLSE.data() = mLSE.data() + int(kBlockM); + mdPsum.data() = mdPsum.data() + int(kBlockM); + + tileloadQ = typename Trait::TiledLoadQ{mQ}; + tileloaddO = typename Trait::TiledLoaddO{mdO}; + tileloaddOt = typename Trait::TiledLoaddOt{mdOt}; + tileloaddQ = typename Trait::TiledLoaddQ{mdQaccum}; + tileloadQt = typename Trait::TiledLoadQt{mQt}; + tilesavedQ = typename Trait::TiledSavedQ{mdQaccum}; + + } + auto tdVrdVl = make_tensor_like(tdVrdV); + convert_type(converter, tdVrdV, tdVrdVl); + copy(tilesavedV, tdVrdVl, tdVgdV); + auto tdKrdKl = make_tensor_like(tdKrdK); + convert_type(converter,tdKrdK, tdKrdKl); + copy(tilesavedK, tdKrdKl, tdKgdK); +} + +template +auto convert_layout_2d_layout(Layout layout) { + auto l = make_layout(make_layout(get<0>(layout), + get<1>(layout)), + get<2>(layout)); + return l; +} + +template +void +compute_o_dot_do(T &trait, Param ¶m, + const int m_block, const int bidb, const int bidh) { + // The thread index. 
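+  // Hedged sketch of what this pre-pass produces (illustrative names only, not
+  // part of the kernel): for every query row i of the current M-block it
+  // computes
+  //   dP_sum(i) = sum_d dO(i, d) * O(i, d)
+  // and writes it to param.odo_ptr, while also zero-filling the float dQ
+  // accumulator that dq_dk_dv_1colblock later reads back. A plain scalar
+  // reference would be roughly:
+  //   for (int i = 0; i < seq_len_q; ++i) {
+  //     float acc = 0.f;
+  //     for (int d = 0; d < head_dim; ++d)
+  //       acc += float(dO[i][d]) * float(O[i][d]);
+  //     dP_sum[i] = acc;
+  //   }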
+ constexpr int kBlockM = T::kBlockM; + constexpr int kBlockN = T::kBlockN; + constexpr int kHeadDim = T::kHeadDim; + constexpr int kNSGs = T::kNSGs; + using DType = typename T::DType; + + auto sg = syclcompat::get_nd_item<1>().get_sub_group(); + auto group = syclcompat::get_nd_item<1>().get_group(); + auto first_thread_in_sg_idx = sg.get_group_linear_id() * trait.SubgroupSize; + auto bofst = Boffset(param); + + const index_t o_offset = bofst.o_offset(bidb, bidh, m_block * kBlockM); + const index_t dq_offset = bofst.dq_offset(bidb, bidh, m_block * kBlockM); + const index_t dpsum_offset = bofst.lse_offset(bidb, bidh, m_block * kBlockM); + + using ShapeO = Shape< + std::conditional_t , int>, + Int, Int<1>>; + using ShapeP = Shape< + std::conditional_t , int>>; + ShapeO O_shape; + ShapeP dP_shape; + if constexpr(Is_even_M) { + O_shape = make_shape(Int{}, Int{}, _1{}); + dP_shape = make_shape(Int{}); + } else { + O_shape = make_shape(param.tail_m, Int{}, _1{}); + dP_shape = make_shape(param.tail_m); + } + Shape dQ_shape = make_shape(Int{}, Int{}, _1{}); + + Tensor mdO = make_tensor(make_gmem_ptr(param.do_ptr + o_offset), + make_layout( + O_shape, + make_stride(param.o_r_stride, _1{}, _1{}))); + Tensor mO = make_tensor(make_gmem_ptr(param.o_ptr + o_offset), + make_layout( + O_shape, + make_stride(param.o_r_stride, _1{}, _1{}))); + Tensor mdQaccum = make_tensor(make_gmem_ptr(param.dqaccum_ptr + dq_offset), + make_layout( + dQ_shape, + make_stride(param.dq_r_stride, _1{}))); + Tensor mdPsum = make_tensor(make_gmem_ptr(param.odo_ptr + dpsum_offset), + make_layout( + dP_shape, + Stride<_1>{})); + + Shape tile_dO = typename T::TileShapedQ{}; + auto tileloaddO = typename T::TiledLoaddP{mdO}; + auto tileloadO = typename T::TiledLoaddP{mO}; + auto tilesavedQ = typename T::TiledSavedQ{mdQaccum}; + + typename T::TiledMmadQ tiled_mma_dq; + auto thr_mma_do = tiled_mma_dq.get_slice(syclcompat::local_id::x()); + + Tensor mO_coord = cute::get_xe_tensor(O_shape); + Tensor dQ_coord = cute::get_xe_tensor(dQ_shape); + + Tensor gdO = local_tile(mO_coord, select<0, 1>(tile_dO), make_coord(0, 0, 0)); + Tensor gdQ = local_tile(dQ_coord, select<0, 1>(tile_dO), make_coord(0, 0, 0)); + + Tensor tdOgdO = thr_mma_do.partition_C(gdO); + Tensor tdQgdQ = thr_mma_do.partition_C(gdQ); + + Tensor tdQrdQ = partition_fragment_C(tiled_mma_dq, Shape, Int>{}); + Tensor tdOrdO = make_fragment_like(tdQrdQ); + Tensor tdOrO = make_fragment_like(tdQrdQ); + clear(tdOrdO); + clear(tdOrO); + clear(tdQrdQ); + copy(tilesavedQ, tdQrdQ, tdQgdQ); + copy(tileloaddO, tdOgdO, tdOrdO); + copy(tileloadO, tdOgdO, tdOrO); + + Tensor dO_reshaped = make_tensor(tdOrdO.data(), convert_layout_2d_layout(tdOrdO.layout())); + Tensor O_reshaped = make_tensor(tdOrO.data(), dO_reshaped.layout()); + constexpr int kGmemThreadsPerRow = kBlockM / size<0>(dO_reshaped); + Tensor tdOrdP = make_fragment_like(size<0>(dO_reshaped)); + constexpr int NUM_ROW_PER_THD = size<0>(dO_reshaped); // 32 + constexpr int NUM_SG_PER_ROW = kNSGs / (kBlockM / size<0>(dO_reshaped)); // 4 + constexpr int NUM_SG_PER_BLK_M = kBlockM / NUM_ROW_PER_THD; // 2 + + const int sg_local_id = sg.get_local_id(); + const int sg_group_id = sg.get_group_id(); + const int sg_group_id_M = sg_group_id % NUM_SG_PER_BLK_M; + const int sg_group_id_N = sg_group_id / NUM_SG_PER_BLK_M; + auto smem = syclcompat::local_mem(); + Tensor stensor = make_tensor(make_smem_ptr(smem), make_shape(Int{}, Int{}, Int{})); + + CUTLASS_PRAGMA_UNROLL + for (int mi = 0; mi < size<0>(dO_reshaped); ++mi) { + float dP_sum = 0.0f; + 
CUTLASS_PRAGMA_UNROLL + for (int ni = 0; ni < size<1>(dO_reshaped); ++ni) { + dP_sum = dP_sum + (float)dO_reshaped(mi, ni) * (float)O_reshaped(mi, ni); + } + tdOrdP(mi) = dP_sum; + } + /* + * reduce within subgroup + */ + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(tdOrdP); ++i) { + tdOrdP(i) = reduce_over_group(sg, tdOrdP(i), sycl::plus<>()); + } + /* + * store to smem + */ + if (sg_local_id == 0) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(tdOrdP); ++i) { + stensor(i, sg_group_id_N, sg_group_id_M) = tdOrdP(i); + } + } + + sycl::group_barrier(group); + /* + * reduce all sgs in the same row + */ + if (sg_local_id == 0 and sg_group_id_N == 0) { + for (int i = 0; i < size<0>(tdOrdP); ++i) { + tdOrdP(i) = 0.0f; + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < NUM_SG_PER_ROW; ++j) { + tdOrdP(i) += stensor(i, j, sg_group_id_M); + } + } + } + // /* + // * broadcast to all threads in the sg 0 + // */ + // CUTLASS_PRAGMA_UNROLL + // for (int i = 0; i < size<0>(tdOrdP); ++i) { + // tdOrdP(i) = sycl::group_broadcast(sg, tdOrdP(i), 0); + // } + /* + * write back to global memory + */ + if constexpr(Is_even_M) { + if (sg_local_id == 0 and sg_group_id_N == 0) { + const int offset = sg_group_id_M * NUM_ROW_PER_THD; + for (int i = 0; i < size<0>(tdOrdP); ++i) { + mdPsum(i + offset) = tdOrdP(i); + } + } + } else { + if (sg_local_id == 0 and sg_group_id_N == 0) { + const int offset = sg_group_id_M * NUM_ROW_PER_THD; + int j = offset; + for (int i = 0; i < size<0>(tdOrdP) and j < param.tail_m; ++i,++j) { + mdPsum(j) = tdOrdP(i); + } + } + } +} + +template +void +mha_backward(T trait, + Param param) { + const int bidb = BlockIdxZ(); + const int bidh = BlockIdxY(); + constexpr bool parallel_seq_kv = false; + // const int max_n_block = ceil_div(param.seq_len_kv, trait.kBlockN); + for (int n_block = 0; n_block < param.n_block; ++n_block) + dq_dk_dv_1colblock(trait, param, bidb, bidh, n_block); + if (param.tail_n > 0) + dq_dk_dv_1colblock(trait, param, bidb, bidh, param.n_block, param.tail_n); +} + +template +void +mha_dot_do_o(T trait, + Param param) { + // The block index for the M dimension. + const int m_block = BlockIdxX(); + // The block index for the batch. + const int bidb = BlockIdxZ(); + // The block index for the head. 
+ const int bidh = BlockIdxY();; + if (m_block == param.m_block - 1 and param.tail_m > 0) { + compute_o_dot_do(trait, param, m_block, bidb, bidh); + } else { + compute_o_dot_do(trait, param, m_block, bidb, bidh); + } +} + +template +void +convert_dq(T &trait, Param ¶m, int m_block, int bidb, int bidh) { + constexpr int kBlockM = T::kBlockM; + constexpr int kBlockN = T::kBlockN; + constexpr int kHeadDim = T::kHeadDim; + constexpr int kNSGs = T::kNSGs; + using DType = typename T::DType; + using VType = typename T::VType; + + auto bofst = Boffset(param); + const index_t dq_offset = bofst.dq_offset(bidb, bidh, m_block * kBlockM); + const index_t q_offset = bofst.q_offset(bidb, bidh, m_block * kBlockM); + VType * dQaccum = param.dqaccum_ptr + dq_offset; + DType * dQ = param.dq_ptr + q_offset; + + int tail_m = param.seq_len_q - m_block * kBlockM; + int m = ThreadIdxX(); + if (m < tail_m) { + for (int h = 0; h < kHeadDim; ++h) { + dQ[m * param.q_r_stride + h] = static_cast(dQaccum[m * param.dq_r_stride + h]); + } + } +} + +template +void +convert_dq(T &trait, Param ¶m, int m_block, int bidb, int bidh) { + constexpr int kBlockM = T::kBlockM; + constexpr int kBlockN = T::kBlockN; + constexpr int kHeadDim = T::kHeadDim; + using DType = typename T::DType; + using VType = typename T::VType; + auto sg = syclcompat::get_nd_item<1>().get_sub_group(); + auto first_thread_in_sg_idx = sg.get_group_linear_id() * trait.SubgroupSize; + + auto bofst = Boffset(param); + const index_t dq_offset = bofst.dq_offset(bidb, bidh, m_block * kBlockM); + const index_t q_offset = bofst.q_offset(bidb, bidh, m_block * kBlockM); + using ShapeQ = Shape< + std::conditional_t, int>, + Int, _1>; + ShapeQ shapeQ; + if constexpr (Is_even_M) { + shapeQ = make_shape(Int{}, Int{}, _1{}); + } else { + shapeQ = make_shape(param.tail_m, Int{}, _1{}); + } + + Tensor mdQaccum = make_tensor(make_gmem_ptr(param.dqaccum_ptr + dq_offset), + make_layout( + shapeQ, + make_stride(param.dq_r_stride, _1{}, _1{}))); + Tensor mdQ = make_tensor(make_gmem_ptr(param.dq_ptr + q_offset), + make_layout( + shapeQ, + make_stride(param.q_r_stride, _1{}, _1{}))); + + Shape tile_dq = typename T::TileShapedQ{}; + + auto tileloaddQ = typename T::TiledLoaddQ{mdQaccum}; + auto tilesavedQ = typename T::TiledSavedV{mdQ}; + + + typename T::TiledMmadQ tiled_mma_dq; + auto thr_mma_dq = tiled_mma_dq.get_slice(first_thread_in_sg_idx); + + Tensor mQ_coord = cute::get_xe_tensor(shapeQ); + Tensor gdQ = local_tile(mQ_coord, select<0, 1>(tile_dq), make_coord(0, 0, 0)); // dump dQ + + Tensor tdQgdQ = thr_mma_dq.partition_C(gdQ); // save to dq + Tensor tdQrdQaccum = partition_fragment_C(tiled_mma_dq, Shape, Int>{}); + + Tensor tdQrdQ = make_fragment_like(tdQrdQaccum); + copy(tileloaddQ, tdQgdQ, tdQrdQaccum); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tdQrdQ); ++i) { + tdQrdQ(i) = static_cast(tdQrdQaccum(i)); + } + copy(tilesavedQ, tdQrdQ, tdQgdQ); +} + +template +void +mhd_convert_dq(T trait, + Param param) { + // The block index for the M dimension. + const int m_block = BlockIdxX(); + // The block index for the batch. + const int bidb = BlockIdxZ(); + // The block index for the head. 
+ const int bidh = BlockIdxY(); + if (param.tail_m > 0 and m_block == param.m_block - 1) { + convert_dq(trait, param, m_block, bidb, bidh); + } else { + convert_dq(trait, param, m_block, bidb, bidh); + } +} + +template +void launch_mha_backward_headdim(ProblemShape problem_shape, + const T *do_d, + const T *o_d, + const T *q_d, + const T *k_d, + const T *v_d, + const float *lse_d, + float *odo_d, + float *dqaccum_d, + T *dq_d, + T *dk_d, + T *dv_d, + T *s_d, + T *dp_d, + const int seq_len_q_pad, + const int seq_len_kv_pad) { + constexpr int numSGs = 8; + constexpr int kBlockK = 32; + auto trait = FAKernel{}; + + const int BATCH = get<0>(problem_shape); + const int NUM_HEAD_Q = get<1>(problem_shape); + const int NUM_HEAD_KV = get<2>(problem_shape); + const int SEQ_LEN_Q = get<3>(problem_shape); + const int SEQ_LEN_KV = get<4>(problem_shape); + const int N_BLOCK = SEQ_LEN_KV / kBlockN; + const int tail_n = SEQ_LEN_KV % kBlockN; + const int M_BLOCK = ceil_div(SEQ_LEN_Q, kBlockM); + const int tail_m = SEQ_LEN_Q % kBlockM; + T * pbuff = syclcompat::malloc(BATCH * NUM_HEAD_Q * seq_len_kv_pad * kBlockM); + auto param = Param(do_d, o_d, q_d, k_d, v_d, lse_d, odo_d, + dqaccum_d, dq_d, dk_d, dv_d, s_d, dp_d, pbuff, + 1 / sqrt(static_cast(kHeadDim))); + param.batch = BATCH; + param.num_head_q = NUM_HEAD_Q; + param.num_head_kv = NUM_HEAD_KV; + param.seq_len_q = SEQ_LEN_Q; + param.seq_len_kv = SEQ_LEN_KV; + param.head_dim = kHeadDim; + param.n_block = N_BLOCK; + param.tail_n = tail_n; + param.m_block = M_BLOCK; + param.tail_m = tail_m; + param.seq_len_kv_pad = seq_len_kv_pad; + param.seq_len_q_pad = seq_len_q_pad; + if constexpr(is_bhsd) { + setup_bhsd_stride(param); + } else { + setup_bshd_stride(param); + } + + auto dimGrid0 = syclcompat::dim3(size(M_BLOCK), size(param.num_head_q), size(param.batch)); + auto dimBlock0 = syclcompat::dim3(size(numSGs * trait.SubgroupSize), size(1), size(1)); + syclcompat::experimental::launch_properties launch_props0{ + // sycl::ext::oneapi::experimental::work_group_scratch_size(0), + }; + syclcompat::experimental::kernel_properties kernel_props0{ + sycl::ext::oneapi::experimental::sub_group_size}; + syclcompat::experimental::launch_policy policy0{dimGrid0, dimBlock0, launch_props0, kernel_props0}; + auto event0 = syclcompat::experimental::launch< + mha_dot_do_o>(policy0, + trait, + param); + EventManager::getInstance().addEvent(event0); + + auto dimGrid1 = syclcompat::dim3(size(1), size(param.num_head_q), size(param.batch)); + assert((trait.num_head_q % trait.num_head_kv == 0) && "num_head_q must be dividable by num_head_kv"); + assert((trait.num_head_q >= trait.num_head_kv) && "num_head_q must be bigger than or equal to num_head_kv"); + auto dimBlock1 = syclcompat::dim3(size(numSGs * trait.SubgroupSize), size(1), size(1)); + // auto dimBlock = syclcompat::dim3(size(trait.tiled_mma_sdp)); + + syclcompat::experimental::launch_properties launch_props1{ + // sycl::ext::oneapi::experimental::work_group_scratch_size(0), + }; + syclcompat::experimental::kernel_properties kernel_props1{ + sycl::ext::oneapi::experimental::sub_group_size}; + syclcompat::experimental::launch_policy policy1{dimGrid1, dimBlock1, launch_props1, kernel_props1}; + auto event1 = syclcompat::experimental::launch< + mha_backward>(policy1, + trait, + param); + EventManager::getInstance().addEvent(event1); + + auto dimGrid2 = syclcompat::dim3(size(M_BLOCK), size(param.num_head_q), size(param.batch)); + auto dimBlock2 = syclcompat::dim3(size(numSGs * trait.SubgroupSize), size(1), size(1)); + 
syclcompat::experimental::launch_properties launch_props2{ + // sycl::ext::oneapi::experimental::work_group_scratch_size(0), + }; + syclcompat::experimental::kernel_properties kernel_props2{ + sycl::ext::oneapi::experimental::sub_group_size}; + syclcompat::experimental::launch_policy policy2{dimGrid2, dimBlock2, launch_props2, kernel_props2}; + auto event2 = syclcompat::experimental::launch< + mhd_convert_dq>(policy2, + trait, + param); + EventManager::getInstance().addEvent(event2); +} + +template +void launch_mha_backward(ProblemShape problem_shape, + const T *do_d, + const T *o_d, + const T *q_d, + const T *k_d, + const T *v_d, + const float *lse_d, + float *odo_d, + float *dqaccum_d, + T *dq_d, + T *dk_d, + T *dv_d, + T *s_d, + T *dp_d, + const int seq_len_q_pad, + const int seq_len_kv_pad) { + const int headdim = get<5>(problem_shape); + if (headdim == 64) { + constexpr int kHeadDim = 64; + launch_mha_backward_headdim( + problem_shape, + do_d, o_d, q_d, k_d, v_d, + lse_d, odo_d, + dqaccum_d, dq_d, dk_d, dv_d, + s_d, dp_d, + seq_len_q_pad, seq_len_kv_pad); + } else if (headdim == 96) { + constexpr int kHeadDim = 96; + launch_mha_backward_headdim( + problem_shape, + do_d, o_d, q_d, k_d, v_d, + lse_d, odo_d, + dqaccum_d, dq_d, dk_d, dv_d, + s_d, dp_d, + seq_len_q_pad, seq_len_kv_pad); + } else if (headdim == 128) { + constexpr int kHeadDim = 128; + launch_mha_backward_headdim( + problem_shape, + do_d, o_d, q_d, k_d, v_d, + lse_d, odo_d, + dqaccum_d, dq_d, dk_d, dv_d, + s_d, dp_d, + seq_len_q_pad, seq_len_kv_pad); + } else if (headdim == 192) { + constexpr int kHeadDim = 192; + launch_mha_backward_headdim( + problem_shape, + do_d, o_d, q_d, k_d, v_d, + lse_d, odo_d, + dqaccum_d, dq_d, dk_d, dv_d, + s_d, dp_d, + seq_len_q_pad, seq_len_kv_pad); + } else if (headdim == 256) { + constexpr int kHeadDim = 256; + launch_mha_backward_headdim( + problem_shape, + do_d, o_d, q_d, k_d, v_d, + lse_d, odo_d, + dqaccum_d, dq_d, dk_d, dv_d, + s_d, dp_d, + seq_len_q_pad, seq_len_kv_pad); + } else { + assert(false && "only support headdim 64,96,128,192,256"); + } +} + +int main(int argc, char**argv) { + // using T = cute::bfloat16_t; + using T = cute::half_t; + using V = float; + std::string data_file = "mha.npz"; + // read qkv + cnpy::NpyArray q_npy = cnpy::npz_load(data_file, "q"); + cnpy::NpyArray k_npy = cnpy::npz_load(data_file, "k"); + cnpy::NpyArray v_npy = cnpy::npz_load(data_file, "v"); + + // read grad output + cnpy::NpyArray do_npy = cnpy::npz_load(data_file, "grad"); + cnpy::NpyArray o_npy = cnpy::npz_load(data_file, "out"); + + // read lse + cnpy::NpyArray lse_npy = cnpy::npz_load(data_file, "lse"); + // read odo + cnpy::NpyArray odo_npy = cnpy::npz_load(data_file, "odo"); + + // read grad reference + cnpy::NpyArray dq_npy = cnpy::npz_load(data_file, "q_grad"); + cnpy::NpyArray dk_npy = cnpy::npz_load(data_file, "k_grad"); + cnpy::NpyArray dv_npy = cnpy::npz_load(data_file, "v_grad"); + cnpy::NpyArray dp_npy = cnpy::npz_load(data_file, "p_grad"); + cnpy::NpyArray ds_npy = cnpy::npz_load(data_file, "s_grad"); + + // read shape + cnpy::NpyArray shape = cnpy::npz_load(data_file, "shape"); + + int64_t BATCH = shape.data()[0]; + int64_t NUM_HEAD_Q = shape.data()[1]; + int64_t NUM_HEAD_KV = shape.data()[2]; + int64_t SEQ_LEN_QO = shape.data()[3]; + int64_t SEQ_LEN_KV = shape.data()[4]; + int64_t HEAD_SIZE_QK = shape.data()[5]; + int64_t HEAD_SIZE_VO = shape.data()[6]; + bool is_causal = shape.data()[7]; + bool is_bhsd = shape.data()[8]; + assert(HEAD_SIZE_QK == HEAD_SIZE_VO && "only support 
head_size_qk==head_size_vo"); + constexpr int kBlockN = 32; + constexpr int kBlockM = 64; + int64_t SEQ_LEN_QO_PAD = ceil_div(SEQ_LEN_QO, kBlockM) * kBlockM; + int64_t SEQ_LEN_KV_PAD = ceil_div(SEQ_LEN_KV, kBlockN) * kBlockN; + printf("batch %d nh_q %d nh_k %d sq_q %d(%d) sq_k %d(%d) hd_q %d hd_v %d causal %d bhsd %d\n", BATCH, NUM_HEAD_Q, NUM_HEAD_KV, SEQ_LEN_QO, SEQ_LEN_QO_PAD, SEQ_LEN_KV, SEQ_LEN_KV_PAD, HEAD_SIZE_QK, HEAD_SIZE_VO, is_causal, is_bhsd); + // read_args(argc, argv, 1, &BATCH); + // read_args(argc, argv, 2, &NUM_HEAD_Q); + // read_args(argc, argv, 3, &NUM_HEAD_KV); + // read_args(argc, argv, 4, &SEQ_LEN_QO); + // read_args(argc, argv, 5, &SEQ_LEN_KV); + // read_args(argc, argv, 6, &HEAD_SIZE_QK); + // read_args(argc, argv, 7, &HEAD_SIZE_VO); + + // alloc qkv + T *q_d = syclcompat::malloc(q_npy.num_vals); + T *k_d = syclcompat::malloc(k_npy.num_vals); + T *v_d = syclcompat::malloc(v_npy.num_vals); + + // alloc ps + T *p_d = syclcompat::malloc(BATCH * NUM_HEAD_Q * SEQ_LEN_QO_PAD * SEQ_LEN_KV_PAD); + T *s_d = syclcompat::malloc(BATCH * NUM_HEAD_Q * SEQ_LEN_QO_PAD * SEQ_LEN_KV_PAD); + + // alloc lse, odo + V *lse_d = syclcompat::malloc(lse_npy.num_vals); + V *odo_d = syclcompat::malloc(odo_npy.num_vals); + + // alloc grad output + T *do_d = syclcompat::malloc(do_npy.num_vals); + T *o_d = syclcompat::malloc(o_npy.num_vals); + + // alloc grad test on device + T *dq_d = syclcompat::malloc(dq_npy.num_vals); + V *dqaccum_d = syclcompat::malloc(BATCH * NUM_HEAD_Q * SEQ_LEN_QO_PAD * HEAD_SIZE_QK); + T *dk_d = syclcompat::malloc(dk_npy.num_vals); + T *dv_d = syclcompat::malloc(dv_npy.num_vals); + T *dp_d = syclcompat::malloc(BATCH * NUM_HEAD_Q * SEQ_LEN_QO_PAD * SEQ_LEN_KV_PAD); + // copy qkv + syclcompat::memcpy(q_d, q_npy.data(), q_npy.num_vals); + syclcompat::memcpy(k_d, k_npy.data(), k_npy.num_vals); + syclcompat::memcpy(v_d, v_npy.data(), v_npy.num_vals); + + // copy grad output + syclcompat::memcpy(do_d, do_npy.data(), do_npy.num_vals); + syclcompat::memcpy(o_d, o_npy.data(), o_npy.num_vals); + + // copy lse + syclcompat::memcpy(lse_d, lse_npy.data(), lse_npy.num_vals); + + // copy odo + // syclcompat::memcpy(odo_d, odo_npy.data(), odo_npy.num_vals); + + auto problem_shape = ProblemShapeRegular(BATCH, NUM_HEAD_Q, NUM_HEAD_KV, SEQ_LEN_QO, SEQ_LEN_KV, HEAD_SIZE_QK, HEAD_SIZE_VO); + if (is_bhsd) { + launch_mha_backward( + problem_shape, + do_d, o_d, + q_d, k_d, v_d, + lse_d, odo_d, + dqaccum_d, dq_d, dk_d, dv_d, + s_d, dp_d, SEQ_LEN_QO_PAD, SEQ_LEN_KV_PAD); + } else { + launch_mha_backward( + problem_shape, + do_d, o_d, + q_d, k_d, v_d, + lse_d, odo_d, + dqaccum_d, dq_d, dk_d, dv_d, + s_d, dp_d, SEQ_LEN_QO_PAD, SEQ_LEN_KV_PAD); + } + float atol = 1e-3f; + float rtol = 1e-3f; + + std::vector odo_test(odo_npy.num_vals); + syclcompat::memcpy(odo_test.data(), odo_d, odo_test.size()); + syclcompat::wait_and_throw(); + printf("odo val: "); + verify(odo_npy.data(), odo_test.data(), BATCH, NUM_HEAD_Q, SEQ_LEN_QO, atol, rtol); + + syclcompat::wait_and_throw(); + std::vector dv_test(BATCH * NUM_HEAD_KV * SEQ_LEN_KV * HEAD_SIZE_VO); + syclcompat::memcpy(dv_test.data(), dv_d, dv_test.size()); + syclcompat::wait_and_throw(); + printf("dV val: "); + verify(dv_npy.data(), dv_test.data(), BATCH * NUM_HEAD_KV, SEQ_LEN_KV, HEAD_SIZE_VO, atol, rtol); + + std::vector dk_test(BATCH * NUM_HEAD_KV * SEQ_LEN_KV * HEAD_SIZE_QK); + syclcompat::memcpy(dk_test.data(), dk_d, dk_test.size()); + syclcompat::wait_and_throw(); + printf("dK val: "); + verify(dk_npy.data(), dk_test.data(), BATCH * NUM_HEAD_KV, 
SEQ_LEN_KV, HEAD_SIZE_QK, atol, rtol); + + std::vector dq_test(BATCH * NUM_HEAD_Q * SEQ_LEN_QO * HEAD_SIZE_QK); + syclcompat::memcpy(dq_test.data(), dq_d, dq_test.size()); + syclcompat::wait_and_throw(); + printf("dQ val: "); + verify(dq_npy.data(), dq_test.data(), BATCH * NUM_HEAD_Q, SEQ_LEN_QO, HEAD_SIZE_QK, atol, rtol); +} diff --git a/examples/sycl/sdpa_bwd/sdpa_util.hpp b/examples/sycl/sdpa_bwd/sdpa_util.hpp new file mode 100644 index 0000000000..47959c9370 --- /dev/null +++ b/examples/sycl/sdpa_bwd/sdpa_util.hpp @@ -0,0 +1,231 @@ +#pragma once +#include +#include +#include +#include +#include +#include + +bool isclose(float a, float b, float atol, float rtol) { + return std::abs(a - b) <= atol + rtol * std::abs(b); +} + +template +float +cosinesimilarity(T *refe, V *test, size_t m) { + float ab = 0.0f; + float a2 = 0.0f; + float b2 = 0.0f; + for (size_t i = 0; i < m; ++i) { + float t_f = (float)test[i]; + float r_f = (float)refe[i]; + ab += t_f * r_f; + a2 += t_f * t_f; + b2 += r_f * r_f; + } + float factor = ab / sqrtf(a2 * b2); + // printf("f=%f\n", factor); + return factor; +} + +template +float +cosinesimilarity(T *refe, V *test, size_t L, size_t M, size_t M_PAD, size_t N, size_t N_PAD) { + float ab = 0.0f; + float a2 = 0.0f; + float b2 = 0.0f; + for (size_t l = 0; l < L; ++l) { + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + size_t i = l * M * N + m * N + n; + size_t j = l * M_PAD * N_PAD + m * N_PAD + n; + float r_f = (float)refe[i]; + float t_f = (float)test[j]; + ab += t_f * r_f; + a2 += t_f * t_f; + b2 += r_f * r_f; + } + } + } + float factor = ab / sqrtf(a2 * b2); + // printf("f=%f\n", factor); + return factor; +} + +template +float +cosinesimilarity(T *refe, V *test, size_t B, size_t H, size_t S, size_t S_PAD, size_t D) { + float ab = 0.0f; + float a2 = 0.0f; + float b2 = 0.0f; + if (is_bhsd) { + for (int b = 0; b < B; ++b) { + for (int h = 0; h < H; ++h) { + for (int s = 0; s < S; ++s) { + for (int d = 0; d < D; ++d) { + int i = b * H * S * D + h * S * D + s * D + d; + int j = b * H * S_PAD * D + h * S_PAD * D + s * D + d; + float r_f = (float)refe[i]; + float t_f = (float)test[j]; + ab += t_f * r_f; + a2 += t_f * t_f; + b2 += r_f * r_f; + } + } + } + } + } else { + for (int b = 0; b < B; ++b) { + for (int s = 0; s < S; ++s) { + for (int h = 0; h < H; ++h) { + for (int d = 0; d < D; ++d) { + int i = b * S * H * D + s * H * D + h * D + d; + int j = b * S_PAD * H * D + s * H * D + h * D + d; + float r_f = (float)refe[i]; + float t_f = (float)test[j]; + ab += t_f * r_f; + a2 += t_f * t_f; + b2 += r_f * r_f; + } + } + } + } + } + float factor = ab / sqrtf(a2 * b2); + // printf("f=%f\n", factor); + return factor; +} + +template +bool allclose(T *refe, V *test, int L, int M, int N, float atol, float rtol) { + size_t err = 0; + size_t count = L * M * N; + bool flag = true; + for (int l = 0; l < L; ++l) { + for (int m = 0; m < M; ++m) { + for (int n = 0; n < N; ++n) { + float expect = (float)refe[l * M * N + m * N + n]; + float value = (float)test[l * M * N + m * N + n]; + // printf("(%d, %d, %d) expect: %f value: %f ratio %f\n", l, m, n, expect, value, value / expect); + if (not isclose(expect, value, atol, rtol)) { + printf("(%d, %d, %d) expect: %f value: %f ratio %f\n", l, m, n, expect, value, value / expect); + err++; + } + } + } + } + float ratio = static_cast(count - err) / static_cast(count); + // printf("c=%f (%ld)\n", ratio, err); + // printf("CHECK SUM SUCCESS\n"); + return ratio > 0.99f; +} + +template +bool allclose(T *refe, V *test, int L, int 
M, int M_PAD, int N, int N_PAD, float atol, float rtol) { + size_t err = 0; + size_t count = L * M * N; + bool flag = true; + for (int l = 0; l < L; ++l) { + for (int m = 0; m < M; ++m) { + for (int n = 0; n < N; ++n) { + int i = l * M * N + m * N + n; + int j = l * M_PAD * N_PAD + m * N_PAD + n; + float expect = (float)refe[i]; + float value = (float)test[j]; + // printf("(%d, %d, %d) expect: %f value: %f ratio %f\n", l, m, n, expect, value, value / expect); + if (not isclose(expect, value, atol, rtol)) { + printf("(%d, %d, %d) expect: %f value: %f ratio %f\n", l, m, n, expect, value, value / expect); + err++; + } + } + } + } + float ratio = static_cast(count - err) / static_cast(count); + // printf("c=%f (%ld)\n", ratio, err); + // printf("CHECK SUM SUCCESS\n"); + return ratio > 0.99f; +} + +template +bool allclose(T *refe, V *test, int B, int H, int S, int S_PAD, int D, float atol, float rtol) { + size_t err = 0; + size_t count = B * S * H * D; + bool flag = true; + if (is_bhsd) { + for (int b = 0; b < B; ++b) { + for (int h = 0; h < H; ++h) { + for (int s = 0; s < S; ++s) { + for (int d = 0; d < D; ++d) { + int i = b * H * S * D + h * S * D + s * D + d; + int j = b * H * S_PAD * D + h * S_PAD * D + s * D + d; + float expect = (float)refe[i]; + float value = (float)test[j]; + if (not isclose(expect, value, atol, rtol)) { + printf("(%d, %d, %d, %d) expect: %f value: %f ratio %f\n", b, h, s, d, expect, value, value / expect); + err++; + } + } + } + } + } + } else { + for (int b = 0; b < B; ++b) { + for (int s = 0; s < S; ++s) { + for (int h = 0; h < H; ++h) { + for (int d = 0; d < D; ++d) { + int i = b * S * H * D + s * H * D + h * D + d; + int j = b * S_PAD * H * D + s * H * D + h * D + d; + float expect = (float)refe[i]; + float value = (float)test[j]; + if (not isclose(expect, value, atol, rtol)) { + printf("(%d, %d, %d, %d) expect: %f value: %f ratio %f\n", b, s, h, d, expect, value, value / expect); + err++; + } + } + } + } + } + } + float ratio = static_cast(count - err) / static_cast(count); + // printf("c=%f (%ld)\n", ratio, err); + // printf("CHECK SUM SUCCESS\n"); + return ratio > 0.99f; +} + +static constexpr char strSUCCESS[] = "\x1B[32mPASS\x1B[0m"; +static constexpr char strFAILURE[] = "\x1B[31mFAIL\x1B[0m"; +template +void verify(T *refe, V *test, int l, int m, int n, float atol, float rtol) { + bool close = allclose(refe, test, l, m, n, atol, rtol); + bool cosine = cosinesimilarity(refe, test, l * m * n) > 0.99f; + printf("%s allclose %s cosinesim %s\n", (close and cosine) ? strSUCCESS : strFAILURE, close ? strSUCCESS : strFAILURE, cosine ? strSUCCESS : strFAILURE); +} + +template +void verify(T *refe, V *test, int l, int m, int m_pad, int n, int n_pad, float atol, float rtol) { + bool close = allclose(refe, test, l, m, m_pad, n, n_pad, atol, rtol); + bool cosine = cosinesimilarity(refe, test, l, m, m_pad, n, n_pad) > 0.99f; + printf("%s allclose %s cosinesim %s\n", (close and cosine) ? strSUCCESS : strFAILURE, close ? strSUCCESS : strFAILURE, cosine ? strSUCCESS : strFAILURE); +} + +template +void verify(T *refe, V *test, int b, int h, int s, int s_pad, int d, float atol, float rtol) { + bool close = allclose(refe, test, b, h, s, s_pad, d, atol, rtol); + bool cosine = cosinesimilarity(refe, test, b, h, s, s_pad, d) > 0.99f; + printf("%s allclose %s cosinesim %s\n", (close and cosine) ? strSUCCESS : strFAILURE, close ? strSUCCESS : strFAILURE, cosine ? 
strSUCCESS : strFAILURE); +} + +template +void read_file(T *ptr, std::string filename, size_t rsize) { + std::ifstream file(filename, std::ios::in | std::ios::binary | std::ios::ate); + if (file.is_open()) { + size_t fsize = file.tellg(); + assert(fsize == rsize); + size_t len = fsize / sizeof(T); + file.seekg(0, std::ios::beg); + file.read((char *)ptr, len * sizeof(T)); + file.close(); + } else { + std::cout << "fail to open " << filename << std::endl; + } +} diff --git a/examples/sycl/sdpa_bwd/test_sdpa_s.py b/examples/sycl/sdpa_bwd/test_sdpa_s.py new file mode 100644 index 0000000000..8b8139eb69 --- /dev/null +++ b/examples/sycl/sdpa_bwd/test_sdpa_s.py @@ -0,0 +1,356 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.attention import SDPBackend, sdpa_kernel +import numpy as np +def is_close(refe: torch.Tensor, + test: torch.Tensor): + test = test.to(torch.float32) + refe = refe.to(torch.float32) + cosfactor = F.cosine_similarity(test.reshape(-1), refe.reshape(-1), dim=0) > 0.99 + allclose = torch.allclose(test, refe, atol=3e-3, rtol=3e-3) + return cosfactor and allclose + +def num_head_bcast(query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor): + q_num_heads = query.size(-3) + k_num_heads = key.size(-3) + v_num_heads = value.size(-3) + k_dim0 = key.size(0) + k_dim1 = key.size(1) + k_dim2 = key.size(2) + k_dim3 = key.size(3) + v_dim0 = value.size(0) + v_dim1 = value.size(1) + v_dim2 = value.size(2) + v_dim3 = value.size(3) + if (q_num_heads == k_num_heads) and (q_num_heads == v_num_heads): + return key, value + k_repeat = q_num_heads // k_num_heads + v_repeat = q_num_heads // v_num_heads + key = key.repeat_interleave(k_repeat, 1).reshape(k_dim0, k_repeat * k_dim1, k_dim2, k_dim3) + value = value.repeat_interleave(v_repeat, 1).reshape(v_dim0, v_repeat * v_dim1, v_dim2, v_dim3) + return key, value + +def num_head_reduce(expand_grad: torch.Tensor, + x: torch.Tensor): + num_heads_expand = expand_grad.size(-3) + num_heads_orig = x.size(-3) + if (num_heads_expand == num_heads_orig): + return expand_grad + n_repeat = num_heads_expand // num_heads_orig + assert len(x.shape) == 4 + batch, num_head, seq_len, head_size = x.size() + expand_grad = expand_grad.reshape(batch, num_head, n_repeat, seq_len, head_size) + grad = torch.sum(expand_grad, dim=2).reshape(batch, num_head, seq_len, head_size) + return grad + +GRAD_DICT = {} + +def dump_grad(name, value): + global GRAD_DICT + if name not in GRAD_DICT: + GRAD_DICT[name] = value.clone() + else: + print(f'duplicated grad {name}') + return + +def softmax_backward(y: torch.Tensor, + grad_y: torch.Tensor, + scale: float): + orig_dtype = y.dtype + rest_dim = y.shape[:-1] + dim = y.shape[-1] + y = y.to(torch.float32) + grad_y = grad_y.to(torch.float32) + ydy = grad_y * y + sum_row = torch.sum(ydy, dim= -1).reshape(*rest_dim, 1) + grad_x2 = ydy - y * sum_row + grad_x = grad_x2.reshape(*rest_dim, dim) * scale + return grad_x.to(orig_dtype) + +def softmax_backward_odo(p: torch.Tensor, + dp: torch.Tensor, + o: torch.Tensor, + do: torch.Tensor, + scale: float): + orig_dtype = p.dtype + o = o.to(torch.float32) + do = do.to(torch.float32) + p = p.to(torch.float32) + dp = dp.to(torch.float32) + odo = o * do + sum_odo = torch.sum(odo, dim= -1, keepdim=True) + ds = p * (dp - sum_odo) * scale + ds = ds.to(orig_dtype) + return ds, sum_odo + +def dropout_backward(mask: torch.Tensor, + grad_y: torch.Tensor, + dropout_p: float): + return mask * grad_y / (1 - dropout_p) + +def dropout_backward2(grad_y: torch.Tensor, + 
dropout_p: float): + return dropout_backward(mask, grad_y, dropout_p) + +def dropout_forward(seed: int, + dropout_p: float, + x: torch.Tensor): + torch.manual_seed(seed) + mask = torch.empty_like(x).fill_(dropout_p) + prob = torch.bernoulli(mask).logical_not() + y = x * prob / (1 - dropout_p) + return y + +def softmax_causal_backward(y1: torch.Tensor, + y2: torch.Tensor, + grad_y: torch.Tensor): + # y1 attn2 after dropout mask + # y2 attn after softmax, only half mask + orig_dtype = y1.dtype + rest_dim = y1.shape[:-1] + dim = y1.shape[-1] + # seq_len_q = y.size()[-2] + # seq_len_k = y.size()[-1] + # seq_len_q = grad_y.size()[-2] + # seq_len_k = grad_y.size()[-1] + # mask = torch.ones(seq_len_q, seq_len_k, dtype=torch.bool).tril(diagonal=0) + y1 = y1.to(torch.float32) + grad_y = grad_y.to(torch.float32) + grad_y = grad_y + ydy = grad_y * y1 + sum_row = torch.sum(ydy, dim= -1).reshape(*rest_dim, 1) + grad_x2 = ydy - y2 * sum_row + grad_x = grad_x2.reshape(*rest_dim, dim) + return grad_x.to(orig_dtype) + +class SDPA(nn.Module): + def __init__(self, dropout_p) -> None: + super().__init__() + if dropout_p > 0.0: + self.do_m = nn.Dropout(p=dropout_p) + + def forward(self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: float = None): + dtype = q.dtype + self.head_size_q = q.size()[-1] + self.q = q.clone() + self.k = k.clone() + self.v = v.clone() + seq_len_q, seq_len_k = q.size(-2), k.size(-2) + + attn_bias = torch.zeros(seq_len_q, seq_len_k, dtype=dtype) + self.is_causal = is_causal + if is_causal: + temp_mask = torch.ones(seq_len_q, seq_len_k, dtype=torch.bool).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(dtype) + + k_expand, v_expand = num_head_bcast(q, k, v) + #k_expand = k + #v_expand = v + # k_expand.register_hook(lambda x: dump_grad('k_expand_grad', k_expand)) + # v_expand.register_hook(lambda x: dump_grad('v_expand_grad', v_expand)) + #self.p = q@k_expand.transpose(-2, -1) + s = torch.matmul(q, k_expand.transpose(-1, -2)) + s.register_hook(lambda x: dump_grad('s_grad', x)) + self.s = s + s = s.to(torch.float32) + self.softmax_scale = 1 / np.sqrt(self.head_size_q) if scale is None else scale + s = s * self.softmax_scale + attn_bias + sum_row, _ = torch.max(s, dim= -1, keepdim=True) + s = s - sum_row + self.lse = torch.logsumexp(s, dim=-1, keepdim=True) + sum_row + p = torch.softmax(s, dim= -1).to(dtype) + self.p = p + p.register_hook(lambda x: dump_grad('p_grad', x)) + attn = torch.matmul(p, v_expand) + attn.register_hook(lambda x: dump_grad('O_grad', x)) + self.o = attn + return attn + + def backward_ref(self, + o_grad: torch.Tensor): + q_grad = torch.empty_like(self.q) + k_grad = torch.empty_like(self.k) + v_grad = torch.empty_like(self.v) + k_expand, v_expand = num_head_bcast(self.q, self.k, self.v) + # forward + s = torch.matmul(self.q, k_expand.transpose(-1, -2)) + s = s.to(torch.float32) + dtype = self.q.dtype + seq_len_q, seq_len_k = q_grad.size(-2), k_grad.size(-2) + attn_bias = torch.zeros(seq_len_q, seq_len_k, dtype=dtype) + if self.is_causal: + temp_mask = torch.ones(seq_len_q, seq_len_k, dtype=torch.bool).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(dtype) + s = s * self.softmax_scale + attn_bias + p = torch.exp(s - self.lse).to(dtype) + # backward + v_grad = torch.matmul(p.transpose(-1, -2), o_grad) + + p_grad = torch.matmul(o_grad, v_expand.transpose(-1, -2)) + s_grad, odo = softmax_backward_odo(p, 
p_grad, self.o, o_grad, self.softmax_scale) + self.odo = odo + # s_grad = softmax_backward(p, p_grad, self.softmax_scale) + k_grad = torch.matmul(s_grad.transpose(-1, -2), self.q) + q_grad = torch.matmul(s_grad, k_expand) + k_grad = num_head_reduce(k_grad, self.k) + v_grad = num_head_reduce(v_grad, self.v) + return (q_grad, k_grad, v_grad, p_grad, s_grad) + +class ptSDPA(nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: float = None): + dtype = q.dtype + with sdpa_kernel(backends=[SDPBackend.MATH]): + return F.scaled_dot_product_attention(q, k, v, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=True) + +def set_dict(dump_dict, name, value): + if value.dtype == torch.bfloat16 or value.dtype == torch.float16: + dump_dict[name] = value.detach().clone().view(torch.uint16).numpy() + elif value.dtype == torch.bool: + dump_dict[name] = value.detach().clone().to(torch.uint16).numpy() + else: + dump_dict[name] = value.detach().clone().numpy() + +def test_sdpa(dtype, + seed: int, + batch: int, + num_heads_q: int, + num_heads_kv: int, + seq_len_qo: int, + seq_len_kv: int, + head_size_qk: int, + head_size_vo: int, + dropout_p: float = 0.0, + is_causal: bool = False, + is_bhsd: bool = True): + torch.manual_seed(seed) + q = torch.randn(batch, num_heads_q, seq_len_qo, head_size_qk, requires_grad=True).to(dtype) + k = torch.randn(batch, num_heads_kv, seq_len_kv, head_size_qk, requires_grad=True).to(dtype) + v = torch.randn(batch, num_heads_kv, seq_len_kv, head_size_vo, requires_grad=True).to(dtype) + q2 = q.clone() + k2 = k.clone() + v2 = v.clone() + q.retain_grad() + k.retain_grad() + v.retain_grad() + q2.retain_grad() + k2.retain_grad() + v2.retain_grad() + test_model = SDPA(dropout_p).to(dtype) + refe_model = ptSDPA().to(dtype) + + torch.manual_seed(seed) + attn_out = test_model(q, k, v, dropout_p, is_causal) + torch.manual_seed(seed) + attn_out_pt = refe_model(q2, k2, v2, dropout_p, is_causal) + grad = torch.empty_like(attn_out) + torch.manual_seed(seed) + grad.uniform_(-1, 1) + grad = grad.to(dtype) + attn_out.backward(grad) + attn_out_pt.backward(grad) + q_grad, k_grad, v_grad, p_grad, s_grad = test_model.backward_ref(grad) + dump_dict = {} + print(f"seed {seed} bsz {batch} nh_q {num_heads_q} nh_kv {num_heads_kv} sl_qo {seq_len_qo} sl_kv {seq_len_kv} hs_qk {head_size_qk} hs_vo {head_size_vo} dp {dropout_p} is_causal {is_causal} is_bhsd {is_bhsd}") + print('attn_out ', is_close(attn_out, attn_out_pt)) + print('p_grad ', is_close(GRAD_DICT['p_grad'], p_grad)) + # print('s2_grad ', is_close(GRAD_DICT['s2_grad'], s2_grad)) + print('s_grad ', is_close(GRAD_DICT['s_grad'], s_grad)) + print('k_grad ', is_close(k_grad, k2.grad)) + print('q_grad ', is_close(q_grad, q2.grad)) + print('v_grad ', is_close(v_grad, v2.grad)) + if is_bhsd: + set_dict(dump_dict, 'out', attn_out) + set_dict(dump_dict, 'grad', grad) + set_dict(dump_dict, 'v_grad', v_grad) + set_dict(dump_dict, 'k_grad', k_grad) + set_dict(dump_dict, 'q_grad', q_grad) + set_dict(dump_dict, 'q', q) + set_dict(dump_dict, 'k', k) + set_dict(dump_dict, 'v', v) + else: + set_dict(dump_dict, 'out', attn_out.transpose(1, 2).contiguous()) + set_dict(dump_dict, 'grad', grad.transpose(1, 2).contiguous()) + set_dict(dump_dict, 'v_grad', v_grad.transpose(1, 2).contiguous()) + set_dict(dump_dict, 'k_grad', k_grad.transpose(1, 2).contiguous()) + set_dict(dump_dict, 'q_grad', 
q_grad.transpose(1, 2).contiguous())
+        set_dict(dump_dict, 'q', q.transpose(1, 2).contiguous())
+        set_dict(dump_dict, 'k', k.transpose(1, 2).contiguous())
+        set_dict(dump_dict, 'v', v.transpose(1, 2).contiguous())
+    set_dict(dump_dict, 'lse', test_model.lse)
+    set_dict(dump_dict, 'odo', test_model.odo)
+    set_dict(dump_dict, 's', test_model.s)
+    set_dict(dump_dict, 'p', test_model.p)
+    set_dict(dump_dict, 'p_grad', p_grad)
+    set_dict(dump_dict, 's_grad', s_grad)
+    shape = np.array([batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, head_size_qk, head_size_vo, is_causal, is_bhsd], dtype=np.int32)
+    dump_dict['shape'] = shape
+    # print('test', v_grad[0,0:4,0,0:16])
+    # print('upstream', v2.grad[0,0:4,0,0:16])
+    np.savez(f'mha-{batch}-{num_heads_q}-{num_heads_kv}-{seq_len_qo}-{seq_len_kv}-{head_size_qk}-{head_size_vo}-{dropout_p}-{int(is_causal)}-{int(is_bhsd)}.npz', **dump_dict)
+
+def loop_run():
+    global GRAD_DICT
+    for h in [4]:
+        # for seq_q in list(range(512, 512+32)):
+        # for seq_k in list(range(512, 512+32)):
+        for seq_q in [512, 513, 523, 527, 528, 529, 543]:
+            for seq_k in [512, 513, 523, 527, 528, 529, 543]:
+                for dim in [96]:
+                    # print('test_run', 4, 4, h, seq_q, seq_k, dim, dim)
+                    # bhsd
+                    test_sdpa(torch.float16, 123, 4, 4, h, seq_q, seq_k, dim, dim, is_bhsd = True)
+                    GRAD_DICT = {}
+                    # bshd
+                    test_sdpa(torch.float16, 123, 4, 4, h, seq_q, seq_k, dim, dim, is_bhsd = False)
+                    GRAD_DICT = {}
+
+if __name__ == '__main__':
+    # test_sdpa(torch.bfloat16, 123, 128, 4, 4, 900, 900, 128, 128)
+    loop_run()
+    # test_sdpa(torch.float16, 123, 4, 4, 4, 513, 784, 128, 128, is_causal=True)
+    # GRAD_DICT = {}
+    # test_sdpa(torch.float16, 123, 4, 4, 4, 513, 784, 128, 128, is_causal=False)
+    # GRAD_DICT = {}
+    # test_sdpa(torch.float16, 123, 4, 4, 2, 513, 784, 128, 128, is_causal=False)
+    # GRAD_DICT = {}
+    # test_sdpa(torch.float16, 123, 4, 4, 2, 513, 784, 128, 128, is_causal=True)
+    # GRAD_DICT = {}
+    # test_sdpa(torch.float16, 123, 4, 4, 1, 513, 784, 128, 128, is_causal=False)
+    # GRAD_DICT = {}
+    # test_sdpa(torch.float16, 123, 4, 4, 1, 513, 784, 128, 128, is_causal=True)
+    # GRAD_DICT = {}
+    # test_sdpa(torch.bfloat16, 123, 4, 4, 2, 513, 513, 128, 64, is_causal=True)
+    # GRAD_DICT = {}
+    # test_sdpa(torch.bfloat16, 123, 4, 4, 2, 513, 513, 128, 64, 0.3, is_causal=False)
+    # GRAD_DICT = {}
+    # test_sdpa(torch.bfloat16, 123, 4, 4, 2, 513, 513, 128, 64, is_causal=False)
+    # GRAD_DICT = {}
+    # GRAD_DICT = {}
+    # test_sdpa(torch.bfloat16, 123, 4, 4, 4, 513, 513, 128, 64, is_causal=False)
+    # test_sdpa(torch.bfloat16, 123, 4, 4, 1, 513, 513, 128, 64, False)
+    # test_sdpa(torch.bfloat16, 123, 4, 4, 4, 1024, 513, 128, 128)
+    # test_sdpa(123, 2, 16, 1, 513, 513, 128, 128)
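
A rough usage sketch, not part of this patch: loop_run() above writes archives named mha-<batch>-<nh_q>-<nh_kv>-<sq>-<skv>-<hd_qk>-<hd_vo>-<dropout>-<causal>-<bhsd>.npz, while the example's main() loads the fixed name mha.npz, so one generated file has to be exposed under that name. The problem size, the copy step, and running from the directory that holds test_sdpa_s.py are assumptions; only the fp16 dtype, the supported head sizes (64/96/128/192/256), and the two file names come from the code above.

# generate_mha_npz.py -- hypothetical helper, run next to test_sdpa_s.py
import shutil
import torch
from test_sdpa_s import test_sdpa

# Assumed problem size: fp16, equal Q/KV heads, head size 128 (one of 64/96/128/192/256).
batch, nh_q, nh_kv, sq, skv, hd = 4, 4, 4, 512, 512, 128
test_sdpa(torch.float16, 123, batch, nh_q, nh_kv, sq, skv, hd, hd, is_bhsd=True)

# test_sdpa() saved the tensors under its parameterized name; the C++ example reads "mha.npz".
shutil.copy(f"mha-{batch}-{nh_q}-{nh_kv}-{sq}-{skv}-{hd}-{hd}-0.0-0-1.npz", "mha.npz")

The example then checks odo, dV, dK and dQ against the q_grad/k_grad/v_grad references stored in the same archive via the verify() calls in main().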