diff --git a/db/migrations/20220113043033_add_unique_constraint_to_users.cr b/db/migrations/20220113043033_add_unique_constraint_to_users.cr new file mode 100644 index 000000000..421b2e9f9 --- /dev/null +++ b/db/migrations/20220113043033_add_unique_constraint_to_users.cr @@ -0,0 +1,9 @@ +class AddUniqueConstraintToUsers::V20220113043033 < Avram::Migrator::Migration::V1 + def migrate + create_index :users, [:name, :nickname], unique: true + end + + def rollback + drop_index :users, [:name, :nickname] + end +end diff --git a/spec/avram/operations/save_operation_spec.cr b/spec/avram/operations/save_operation_spec.cr index 44e2e769f..c5113f928 100644 --- a/spec/avram/operations/save_operation_spec.cr +++ b/spec/avram/operations/save_operation_spec.cr @@ -66,6 +66,13 @@ private class ParamKeySaveOperation < ValueColumnModel::SaveOperation end private class UpsertUserOperation < User::SaveOperation + include QuerySpy + + upsert_lookup_columns :name, :nickname + upsert_unique_on :name, :nickname +end + +private class UpsertWithoutUniqueKeys < User::SaveOperation upsert_lookup_columns :name, :nickname end @@ -307,6 +314,147 @@ describe "Avram::SaveOperation" do end end + describe ".upsert" do + it "should only proc one query" do + UpsertUserOperation.times_called = 0 + some_time = Time.utc(2016, 2, 15, 10, 20, 30) + + updates = [ + { + name: "Name 1", + nickname: "Nickname 1", + age: 42, + joined_at: some_time, + created_at: some_time, + updated_at: some_time, + }, + { + name: "Name 2", + nickname: "Nickname 2", + age: 42, + joined_at: some_time, + created_at: some_time, + updated_at: some_time, + }, + ] + + UpsertUserOperation.upsert(updates) + UpsertUserOperation.times_called.should eq 1 + end + + context "when a record already exists" do + before_each do + UserFactory.create do |u| + u.name("Name 1") + u.nickname("Nickname 1") + u.age(42) + u.year_born(1960) + u.joined_at(Time.utc) + end + end + + it "allows manual passing of updated_at, but ignores created_at" do + some_time = Time.utc(2016, 2, 15, 10, 20, 30) + + update = { + name: "Name 1", + nickname: "Nickname 1", + age: 42, + joined_at: some_time, + created_at: some_time, + updated_at: some_time, + } + + records = UpsertUserOperation.upsert([update]) + records.first.created_at.should_not eq some_time + records.first.updated_at.should eq some_time + end + + it "should create one, and update the other record" do + update = { + name: "Name 1", + nickname: "Nickname 1", + year_born: nil, + age: 42, + joined_at: Time.utc, + } + + insert = { + name: "Name 2", + nickname: "Nickname 2", + year_born: 1980_i16, + age: 64, + joined_at: Time.utc, + } + + records = UpsertUserOperation.upsert([update, insert]) + + records.first.id.should_not eq nil + records.last.id.should_not eq nil + records.first.year_born.should eq nil + records.last.year_born.should eq 1980_i16 + end + end + + context "when no records exist" do + it "allows manual passing of id" do + insert = { + id: 42_i64, + name: "Name 1", + nickname: "Nickname 1", + age: 42, + joined_at: Time.utc, + } + + records = UpsertUserOperation.upsert([insert]) + records.first.id.should eq 42_i64 + end + + it "allows manual passing of updated_at and created_at" do + some_time = Time.utc(2016, 2, 15, 10, 20, 30) + + insert = { + name: "Name 1", + nickname: "Nickname 1", + age: 42, + joined_at: some_time, + created_at: some_time, + updated_at: some_time, + } + + records = UpsertUserOperation.upsert([insert]) + records.first.id.should_not eq nil + records.first.created_at.should eq some_time + records.first.updated_at.should eq some_time + end + end + + context "when the tuple values are passed in different orders" do + it "should upsert records" do + record_args = [ + { + name: "Name 1", + nickname: "Nickname 1", + year_born: nil, + age: 42, + joined_at: Time.utc, + }, + { + nickname: "Nickname 2", + name: "Name 2", + age: 42, + joined_at: Time.utc, + year_born: nil, + }, + ] + + records = UpsertUserOperation.upsert(record_args) + records.last.nickname.should eq "Nickname 2" + records.last.name.should eq "Name 2" + end + end + end + describe "#errors" do it "includes errors for all operation attributes" do operation = SaveUser.new diff --git a/spec/avram/view_spec.cr b/spec/avram/view_spec.cr index 52e5c9e79..1941a63b8 100644 --- a/spec/avram/view_spec.cr +++ b/spec/avram/view_spec.cr @@ -12,9 +12,9 @@ describe "views" do end it "works without a primary key" do - UserFactory.new.nickname("Johnny").create - UserFactory.new.nickname("Johnny").create - UserFactory.new.nickname("Johnny").create + UserFactory.new.name("P1").nickname("Johnny").create + UserFactory.new.name("P2").nickname("Johnny").create + UserFactory.new.name("P3").nickname("Johnny").create nickname_info = NicknameInfo::BaseQuery.first nickname_info.nickname.should eq "Johnny" diff --git a/src/avram/bulk_upsert.cr b/src/avram/bulk_upsert.cr new file mode 100644 index 000000000..7c96b23ad --- /dev/null +++ b/src/avram/bulk_upsert.cr @@ -0,0 +1,89 @@ +class Avram::BulkUpsert(T) + @column_types : Hash(String, String) + @permitted_fields : Array(Symbol) + + def initialize(@records : Array(T), + @conflicts : Array(Symbol), + permitted_fields : Array(Symbol)) + set_timestamps + @sample_record = @records.first.as(T) + @permitted_fields = permitted_fields_for(permitted_fields) + + @column_types = T.database_table_info.columns.map do |col_info| + { + col_info.column_name, + col_info.data_type, + } + end.to_h + end + + def statement + <<-SQL + INSERT INTO #{table}(#{fields}) + (SELECT * FROM unnest(#{value_placeholders})) + ON CONFLICT (#{conflicts}) DO UPDATE SET #{updates} + RETURNING #{returning} + SQL + end + + def args + @records.map do |record| + permitted_attributes(record).map(&.value) + end.transpose + end + + private def permitted_fields_for(fields : Array(Symbol)) + fields.push(:created_at) if @sample_record.responds_to?(:created_at) + fields.push(:updated_at) if @sample_record.responds_to?(:updated_at) + fields.uniq! + end + + private def permitted_attributes(record) + record + .attributes + .select { |attr| @permitted_fields.includes?(attr.name) } + end + + private def permitted_attributes + permitted_attributes(@sample_record) + end + + private def conflicts + @conflicts.join(", ") + end + + private def set_timestamps + @records.each do |record| + record.created_at.value ||= Time.utc if record.responds_to?(:created_at) + record.updated_at.value ||= Time.utc if record.responds_to?(:updated_at) + end + end + + private def table + @sample_record.table_name + end + + private def updates + (permitted_attribute_column_names - [:created_at]).compact_map do |column| + "#{column}=EXCLUDED.#{column}" + end.join(", ") + end + + private def returning + T.column_names.join(", ") + end + + private def permitted_attribute_column_names + permitted_attributes.map(&.name) + end + + private def fields + permitted_attribute_column_names.map(&.to_s).join(", ") + end + + private def value_placeholders + permitted_attributes.map_with_index(1) do |column, index| + "$#{index}::#{@column_types[column.name.to_s]}[]" + end.join(", ") + end +end diff --git a/src/avram/save_operation.cr b/src/avram/save_operation.cr index e07a152e4..b2754f981 100644 --- a/src/avram/save_operation.cr +++ b/src/avram/save_operation.cr @@ -379,6 +379,14 @@ abstract class Avram::SaveOperation(T) @record.try &.id end + def self.column_names + T.column_names + end + + def self.database_table_info + T.database_table_info.not_nil! + end + def before_save; end def after_save(_record : T); end diff --git a/src/avram/upsert.cr b/src/avram/upsert.cr index 9d1ae5815..4b7a1c017 100644 --- a/src/avram/upsert.cr +++ b/src/avram/upsert.cr @@ -90,6 +90,29 @@ module Avram::Upsert end end + macro upsert_unique_on(*attribute_names) + def self.upsert(upserts : Array(X)) forall X + \{% + if X > NamedTuple + raise("All array elements for #{@type}.upsert must be NamedTuples. You provided: #{X}") + elsif X.union? + keys = X.union_types.map(&.keys).join(", ") + raise("All tuples for #{@type}.upsert must have the same keys. Given: " + keys) + end + %} + + upsert = Avram::BulkUpsert(self).new( + records: upserts.map { |upsert_args| new(**upsert_args) }, + conflicts: {{ attribute_names }}.to_a, + permitted_fields: upserts.first.keys.to_a + ) + + new.database.query upsert.statement, args: upsert.args do |rs| + T.from_rs(rs) + end + end + end + # :nodoc: macro included {% for method in ["upsert", "upsert!"] %} @@ -100,5 +123,9 @@ module Avram::Upsert \{% raise "Please use the 'upsert_lookup_columns' macro in #{@type} before using '{{ method.id }}'" %} end {% end %} + + def self.upsert(_upserts : Array) + \{% raise "Please use the 'upsert_unique_on' macro in #{@type} before using '.upsert'" %} + end end end diff --git a/src/ext/db/param.cr b/src/ext/db/param.cr new file mode 100644 index 000000000..436215538 --- /dev/null +++ b/src/ext/db/param.cr @@ -0,0 +1,8 @@ +# Can be removed once https://github.com/will/crystal-pg/pull/244 is merged. +module PQ + record Param, slice : Slice(UInt8), size : Int32, format : Int16 do + def self.encode_array(io, value : Nil) + io << "NULL" + end + end +end